Recovery Mode (#7559)
* Initial work * rework * s/dr/recovery * Add sys/raw support to recovery mode (#7577) * Factor the raw paths out so they can be run with a SystemBackend. # Conflicts: # vault/logical_system.go * Add handleLogicalRecovery which is like handleLogical but is only sufficient for use with the sys-raw endpoint in recovery mode. No authentication is done yet. * Integrate with recovery-mode. We now handle unauthenticated sys/raw requests, albeit on path v1/raw instead v1/sys/raw. * Use sys/raw instead raw during recovery. * Don't bother persisting the recovery token. Authenticate sys/raw requests with it. * RecoveryMode: Support generate-root for autounseals (#7591) * Recovery: Abstract config creation and log settings * Recovery mode integration test. (#7600) * Recovery: Touch up (#7607) * Recovery: Touch up * revert the raw backend creation changes * Added recovery operation token prefix * Move RawBackend to its own file * Update API path and hit it using CLI flag on generate-root * Fix a panic triggered when handling a request that yields a nil response. (#7618) * Improve integ test to actually make changes while in recovery mode and verify they're still there after coming back in regular mode. * Refuse to allow a second recovery token to be generated. * Resize raft cluster to size 1 and start as leader (#7626) * RecoveryMode: Setup raft cluster post unseal (#7635) * Setup raft cluster post unseal in recovery mode * Remove marking as unsealed as its not needed * Address review comments * Accept only one seal config in recovery mode as there is no scope for migration
This commit is contained in:
parent
ffb699e48c
commit
0d077d7945
|
@ -10,6 +10,10 @@ func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, err
|
|||
return c.generateRootStatusCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRecoveryOperationTokenStatus() (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootStatusCommon("/v1/sys/generate-recovery-token/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) {
|
||||
r := c.c.NewRequest("GET", path)
|
||||
|
||||
|
@ -34,6 +38,10 @@ func (c *Sys) GenerateDROperationTokenInit(otp, pgpKey string) (*GenerateRootSta
|
|||
return c.generateRootInitCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt", otp, pgpKey)
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRecoveryOperationTokenInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootInitCommon("/v1/sys/generate-recovery-token/attempt", otp, pgpKey)
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootStatusResponse, error) {
|
||||
body := map[string]interface{}{
|
||||
"otp": otp,
|
||||
|
@ -66,6 +74,10 @@ func (c *Sys) GenerateDROperationTokenCancel() error {
|
|||
return c.generateRootCancelCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRecoveryOperationTokenCancel() error {
|
||||
return c.generateRootCancelCommon("/v1/sys/generate-recovery-token/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootCancelCommon(path string) error {
|
||||
r := c.c.NewRequest("DELETE", path)
|
||||
|
||||
|
@ -86,6 +98,10 @@ func (c *Sys) GenerateDROperationTokenUpdate(shard, nonce string) (*GenerateRoot
|
|||
return c.generateRootUpdateCommon("/v1/sys/replication/dr/secondary/generate-operation-token/update", shard, nonce)
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRecoveryOperationTokenUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootUpdateCommon("/v1/sys/generate-recovery-token/update", shard, nonce)
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRootStatusResponse, error) {
|
||||
body := map[string]interface{}{
|
||||
"key": shard,
|
||||
|
|
|
@ -23,6 +23,14 @@ import (
|
|||
var _ cli.Command = (*OperatorGenerateRootCommand)(nil)
|
||||
var _ cli.CommandAutocomplete = (*OperatorGenerateRootCommand)(nil)
|
||||
|
||||
type generateRootKind int
|
||||
|
||||
const (
|
||||
generateRootRegular generateRootKind = iota
|
||||
generateRootDR
|
||||
generateRootRecovery
|
||||
)
|
||||
|
||||
type OperatorGenerateRootCommand struct {
|
||||
*BaseCommand
|
||||
|
||||
|
@ -35,6 +43,7 @@ type OperatorGenerateRootCommand struct {
|
|||
flagNonce string
|
||||
flagGenerateOTP bool
|
||||
flagDRToken bool
|
||||
flagRecoveryToken bool
|
||||
|
||||
testStdin io.Reader // for tests
|
||||
}
|
||||
|
@ -143,6 +152,16 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets {
|
|||
"tokens.",
|
||||
})
|
||||
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "recovery-token",
|
||||
Target: &c.flagRecoveryToken,
|
||||
Default: false,
|
||||
EnvVar: "",
|
||||
Completion: complete.PredictNothing,
|
||||
Usage: "Set this flag to do generate root operations on Recovery Operational " +
|
||||
"tokens.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "otp",
|
||||
Target: &c.flagOTP,
|
||||
|
@ -200,43 +219,60 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
if c.flagDRToken && c.flagRecoveryToken {
|
||||
c.UI.Error("Both -recovery-token and -dr-token flags are set")
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 2
|
||||
}
|
||||
|
||||
kind := generateRootRegular
|
||||
switch {
|
||||
case c.flagDRToken:
|
||||
kind = generateRootDR
|
||||
case c.flagRecoveryToken:
|
||||
kind = generateRootRecovery
|
||||
}
|
||||
|
||||
switch {
|
||||
case c.flagGenerateOTP:
|
||||
otp, code := c.generateOTP(client, c.flagDRToken)
|
||||
otp, code := c.generateOTP(client, kind)
|
||||
if code == 0 {
|
||||
return PrintRaw(c.UI, otp)
|
||||
}
|
||||
return code
|
||||
case c.flagDecode != "":
|
||||
return c.decode(client, c.flagDecode, c.flagOTP, c.flagDRToken)
|
||||
return c.decode(client, c.flagDecode, c.flagOTP, kind)
|
||||
case c.flagCancel:
|
||||
return c.cancel(client, c.flagDRToken)
|
||||
return c.cancel(client, kind)
|
||||
case c.flagInit:
|
||||
return c.init(client, c.flagOTP, c.flagPGPKey, c.flagDRToken)
|
||||
return c.init(client, c.flagOTP, c.flagPGPKey, kind)
|
||||
case c.flagStatus:
|
||||
return c.status(client, c.flagDRToken)
|
||||
return c.status(client, kind)
|
||||
default:
|
||||
// If there are no other flags, prompt for an unseal key.
|
||||
key := ""
|
||||
if len(args) > 0 {
|
||||
key = strings.TrimSpace(args[0])
|
||||
}
|
||||
return c.provide(client, key, c.flagDRToken)
|
||||
return c.provide(client, key, kind)
|
||||
}
|
||||
}
|
||||
|
||||
// generateOTP generates a suitable OTP code for generating a root token.
|
||||
func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bool) (string, int) {
|
||||
func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, kind generateRootKind) (string, int) {
|
||||
f := client.Sys().GenerateRootStatus
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
f = client.Sys().GenerateDROperationTokenStatus
|
||||
case generateRootRecovery:
|
||||
f = client.Sys().GenerateRecoveryOperationTokenStatus
|
||||
}
|
||||
|
||||
status, err := f()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
|
||||
|
@ -272,7 +308,7 @@ func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bo
|
|||
}
|
||||
|
||||
// decode decodes the given value using the otp.
|
||||
func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, drToken bool) int {
|
||||
func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, kind generateRootKind) int {
|
||||
if encoded == "" {
|
||||
c.UI.Error("Missing encoded value: use -decode=<string> to supply it")
|
||||
return 1
|
||||
|
@ -283,9 +319,13 @@ func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp st
|
|||
}
|
||||
|
||||
f := client.Sys().GenerateRootStatus
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
f = client.Sys().GenerateDROperationTokenStatus
|
||||
case generateRootRecovery:
|
||||
f = client.Sys().GenerateRecoveryOperationTokenStatus
|
||||
}
|
||||
|
||||
status, err := f()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
|
||||
|
@ -327,7 +367,7 @@ func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp st
|
|||
}
|
||||
|
||||
// init is used to start the generation process
|
||||
func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, drToken bool) int {
|
||||
func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, kind generateRootKind) int {
|
||||
// Validate incoming fields. Either OTP OR PGP keys must be supplied.
|
||||
if otp != "" && pgpKey != "" {
|
||||
c.UI.Error("Error initializing: cannot specify both -otp and -pgp-key")
|
||||
|
@ -336,8 +376,11 @@ func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey strin
|
|||
|
||||
// Start the root generation
|
||||
f := client.Sys().GenerateRootInit
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
f = client.Sys().GenerateDROperationTokenInit
|
||||
case generateRootRecovery:
|
||||
f = client.Sys().GenerateRecoveryOperationTokenInit
|
||||
}
|
||||
status, err := f(otp, pgpKey)
|
||||
if err != nil {
|
||||
|
@ -355,10 +398,13 @@ func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey strin
|
|||
|
||||
// provide prompts the user for the seal key and posts it to the update root
|
||||
// endpoint. If this is the last unseal, this function outputs it.
|
||||
func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, drToken bool) int {
|
||||
func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, kind generateRootKind) int {
|
||||
f := client.Sys().GenerateRootStatus
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
f = client.Sys().GenerateDROperationTokenStatus
|
||||
case generateRootRecovery:
|
||||
f = client.Sys().GenerateRecoveryOperationTokenStatus
|
||||
}
|
||||
status, err := f()
|
||||
if err != nil {
|
||||
|
@ -437,8 +483,11 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr
|
|||
|
||||
// Provide the key, this may potentially complete the update
|
||||
fUpd := client.Sys().GenerateRootUpdate
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
fUpd = client.Sys().GenerateDROperationTokenUpdate
|
||||
case generateRootRecovery:
|
||||
fUpd = client.Sys().GenerateRecoveryOperationTokenUpdate
|
||||
}
|
||||
status, err = fUpd(key, nonce)
|
||||
if err != nil {
|
||||
|
@ -454,10 +503,13 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr
|
|||
}
|
||||
|
||||
// cancel cancels the root token generation
|
||||
func (c *OperatorGenerateRootCommand) cancel(client *api.Client, drToken bool) int {
|
||||
func (c *OperatorGenerateRootCommand) cancel(client *api.Client, kind generateRootKind) int {
|
||||
f := client.Sys().GenerateRootCancel
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
f = client.Sys().GenerateDROperationTokenCancel
|
||||
case generateRootRecovery:
|
||||
f = client.Sys().GenerateRecoveryOperationTokenCancel
|
||||
}
|
||||
if err := f(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error canceling root token generation: %s", err))
|
||||
|
@ -468,11 +520,15 @@ func (c *OperatorGenerateRootCommand) cancel(client *api.Client, drToken bool) i
|
|||
}
|
||||
|
||||
// status is used just to fetch and dump the status
|
||||
func (c *OperatorGenerateRootCommand) status(client *api.Client, drToken bool) int {
|
||||
func (c *OperatorGenerateRootCommand) status(client *api.Client, kind generateRootKind) int {
|
||||
f := client.Sys().GenerateRootStatus
|
||||
if drToken {
|
||||
switch kind {
|
||||
case generateRootDR:
|
||||
f = client.Sys().GenerateDROperationTokenStatus
|
||||
case generateRootRecovery:
|
||||
f = client.Sys().GenerateRecoveryOperationTokenStatus
|
||||
}
|
||||
|
||||
status, err := f()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"go.uber.org/atomic"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -99,6 +100,7 @@ type ServerCommand struct {
|
|||
flagConfigs []string
|
||||
flagLogLevel string
|
||||
flagLogFormat string
|
||||
flagRecovery bool
|
||||
flagDev bool
|
||||
flagDevRootTokenID string
|
||||
flagDevListenAddr string
|
||||
|
@ -197,6 +199,13 @@ func (c *ServerCommand) Flags() *FlagSets {
|
|||
Usage: `Log format. Supported values are "standard" and "json".`,
|
||||
})
|
||||
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "recovery",
|
||||
Target: &c.flagRecovery,
|
||||
Usage: "Enable recovery mode. In this mode, Vault is used to perform recovery actions." +
|
||||
"Using a recovery operation token, \"sys/raw\" API can be used to manipulate the storage.",
|
||||
})
|
||||
|
||||
f = set.NewFlagSet("Dev Options")
|
||||
|
||||
f.BoolVar(&BoolVar{
|
||||
|
@ -365,6 +374,384 @@ func (c *ServerCommand) AutocompleteFlags() complete.Flags {
|
|||
return c.Flags().Completions()
|
||||
}
|
||||
|
||||
func (c *ServerCommand) parseConfig() (*server.Config, error) {
|
||||
// Load the configuration
|
||||
var config *server.Config
|
||||
for _, path := range c.flagConfigs {
|
||||
current, err := server.LoadConfig(path)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error loading configuration from %s: {{err}}", path), err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = current
|
||||
} else {
|
||||
config = config.Merge(current)
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c *ServerCommand) runRecoveryMode() int {
|
||||
config, err := c.parseConfig()
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
// Ensure at least one config was found.
|
||||
if config == nil {
|
||||
c.UI.Output(wrapAtLength(
|
||||
"No configuration files found. Please provide configurations with the " +
|
||||
"-config flag. If you are supplying the path to a directory, please " +
|
||||
"ensure the directory contains files with the .hcl or .json " +
|
||||
"extension."))
|
||||
return 1
|
||||
}
|
||||
|
||||
level, logLevelString, logLevelWasNotSet, logFormat, err := c.processLogLevelAndFormat(config)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
c.logger = log.New(&log.LoggerOptions{
|
||||
Output: c.logWriter,
|
||||
Level: level,
|
||||
// Note that if logFormat is either unspecified or standard, then
|
||||
// the resulting logger's format will be standard.
|
||||
JSONFormat: logFormat == logging.JSONFormat,
|
||||
})
|
||||
|
||||
logLevelStr, err := c.adjustLogLevel(config, logLevelWasNotSet)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
if logLevelStr != "" {
|
||||
logLevelString = logLevelStr
|
||||
}
|
||||
|
||||
// create GRPC logger
|
||||
namedGRPCLogFaker := c.logger.Named("grpclogfaker")
|
||||
grpclog.SetLogger(&grpclogFaker{
|
||||
logger: namedGRPCLogFaker,
|
||||
log: os.Getenv("VAULT_GRPC_LOGGING") != "",
|
||||
})
|
||||
|
||||
if config.Storage == nil {
|
||||
c.UI.Output("A storage backend must be specified")
|
||||
return 1
|
||||
}
|
||||
|
||||
if config.DefaultMaxRequestDuration != 0 {
|
||||
vault.DefaultMaxRequestDuration = config.DefaultMaxRequestDuration
|
||||
}
|
||||
|
||||
proxyCfg := httpproxy.FromEnvironment()
|
||||
c.logger.Info("proxy environment", "http_proxy", proxyCfg.HTTPProxy,
|
||||
"https_proxy", proxyCfg.HTTPSProxy, "no_proxy", proxyCfg.NoProxy)
|
||||
|
||||
// Initialize the storage backend
|
||||
factory, exists := c.PhysicalBackends[config.Storage.Type]
|
||||
if !exists {
|
||||
c.UI.Error(fmt.Sprintf("Unknown storage type %s", config.Storage.Type))
|
||||
return 1
|
||||
}
|
||||
if config.Storage.Type == "raft" {
|
||||
if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" {
|
||||
config.ClusterAddr = envCA
|
||||
}
|
||||
|
||||
if len(config.ClusterAddr) == 0 {
|
||||
c.UI.Error("Cluster address must be set when using raft storage")
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
namedStorageLogger := c.logger.Named("storage." + config.Storage.Type)
|
||||
backend, err := factory(config.Storage.Config, namedStorageLogger)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing storage of type %s: %s", config.Storage.Type, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
infoKeys := make([]string, 0, 10)
|
||||
info := make(map[string]string)
|
||||
info["log level"] = logLevelString
|
||||
infoKeys = append(infoKeys, "log level")
|
||||
|
||||
var barrierSeal vault.Seal
|
||||
var sealConfigError error
|
||||
|
||||
if len(config.Seals) == 0 {
|
||||
config.Seals = append(config.Seals, &server.Seal{Type: vaultseal.Shamir})
|
||||
}
|
||||
|
||||
if len(config.Seals) > 1 {
|
||||
c.UI.Error("Only one seal block is accepted in recovery mode")
|
||||
return 1
|
||||
}
|
||||
|
||||
configSeal := config.Seals[0]
|
||||
sealType := vaultseal.Shamir
|
||||
if !configSeal.Disabled && os.Getenv("VAULT_SEAL_TYPE") != "" {
|
||||
sealType = os.Getenv("VAULT_SEAL_TYPE")
|
||||
configSeal.Type = sealType
|
||||
} else {
|
||||
sealType = configSeal.Type
|
||||
}
|
||||
|
||||
var seal vault.Seal
|
||||
sealLogger := c.logger.Named(sealType)
|
||||
seal, sealConfigError = serverseal.ConfigureSeal(configSeal, &infoKeys, &info, sealLogger, vault.NewDefaultSeal(shamirseal.NewSeal(c.logger.Named("shamir"))))
|
||||
if sealConfigError != nil {
|
||||
if !errwrap.ContainsType(sealConfigError, new(logical.KeyNotFoundError)) {
|
||||
c.UI.Error(fmt.Sprintf(
|
||||
"Error parsing Seal configuration: %s", sealConfigError))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
if seal == nil {
|
||||
c.UI.Error(fmt.Sprintf(
|
||||
"After configuring seal nil returned, seal type was %s", sealType))
|
||||
return 1
|
||||
}
|
||||
|
||||
barrierSeal = seal
|
||||
|
||||
// Ensure that the seal finalizer is called, even if using verify-only
|
||||
defer func() {
|
||||
err = seal.Finalize(context.Background())
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error finalizing seals: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
Physical: backend,
|
||||
StorageType: config.Storage.Type,
|
||||
Seal: barrierSeal,
|
||||
Logger: c.logger,
|
||||
DisableMlock: config.DisableMlock,
|
||||
RecoveryMode: c.flagRecovery,
|
||||
ClusterAddr: config.ClusterAddr,
|
||||
}
|
||||
|
||||
core, newCoreError := vault.NewCore(coreConfig)
|
||||
if newCoreError != nil {
|
||||
if vault.IsFatalError(newCoreError) {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing core: %s", newCoreError))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if err := core.InitializeRecovery(context.Background()); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing core in recovery mode: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Compile server information for output later
|
||||
infoKeys = append(infoKeys, "storage")
|
||||
info["storage"] = config.Storage.Type
|
||||
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
info["cluster address"] = coreConfig.ClusterAddr
|
||||
infoKeys = append(infoKeys, "cluster address")
|
||||
}
|
||||
|
||||
// Initialize the listeners
|
||||
lns := make([]ServerListener, 0, len(config.Listeners))
|
||||
for _, lnConfig := range config.Listeners {
|
||||
ln, _, _, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logWriter, c.UI)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
lns = append(lns, ServerListener{
|
||||
Listener: ln,
|
||||
config: lnConfig.Config,
|
||||
})
|
||||
}
|
||||
|
||||
listenerCloseFunc := func() {
|
||||
for _, ln := range lns {
|
||||
ln.Listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
defer c.cleanupGuard.Do(listenerCloseFunc)
|
||||
|
||||
infoKeys = append(infoKeys, "version")
|
||||
verInfo := version.GetVersion()
|
||||
info["version"] = verInfo.FullVersionNumber(false)
|
||||
if verInfo.Revision != "" {
|
||||
info["version sha"] = strings.Trim(verInfo.Revision, "'")
|
||||
infoKeys = append(infoKeys, "version sha")
|
||||
}
|
||||
|
||||
infoKeys = append(infoKeys, "recovery mode")
|
||||
info["recovery mode"] = "true"
|
||||
|
||||
// Server configuration output
|
||||
padding := 24
|
||||
sort.Strings(infoKeys)
|
||||
c.UI.Output("==> Vault server configuration:\n")
|
||||
for _, k := range infoKeys {
|
||||
c.UI.Output(fmt.Sprintf(
|
||||
"%s%s: %s",
|
||||
strings.Repeat(" ", padding-len(k)),
|
||||
strings.Title(k),
|
||||
info[k]))
|
||||
}
|
||||
c.UI.Output("")
|
||||
|
||||
for _, ln := range lns {
|
||||
handler := vaulthttp.Handler(&vault.HandlerProperties{
|
||||
Core: core,
|
||||
MaxRequestSize: ln.maxRequestSize,
|
||||
MaxRequestDuration: ln.maxRequestDuration,
|
||||
DisablePrintableCheck: config.DisablePrintableCheck,
|
||||
RecoveryMode: c.flagRecovery,
|
||||
RecoveryToken: atomic.NewString(""),
|
||||
})
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: c.logger.StandardLogger(nil),
|
||||
}
|
||||
|
||||
go server.Serve(ln.Listener)
|
||||
}
|
||||
|
||||
if sealConfigError != nil {
|
||||
init, err := core.Initialized(context.Background())
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error checking if core is initialized: %v", err))
|
||||
return 1
|
||||
}
|
||||
if init {
|
||||
c.UI.Error("Vault is initialized but no Seal key could be loaded")
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if newCoreError != nil {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! A non-fatal error occurred during initialization. Please " +
|
||||
"check the logs for more information."))
|
||||
c.UI.Warn("")
|
||||
}
|
||||
|
||||
if !c.flagCombineLogs {
|
||||
c.UI.Output("==> Vault server started! Log data will stream in below:\n")
|
||||
}
|
||||
|
||||
c.logGate.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ShutdownCh:
|
||||
c.UI.Output("==> Vault shutdown triggered")
|
||||
|
||||
c.cleanupGuard.Do(listenerCloseFunc)
|
||||
|
||||
if err := core.Shutdown(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err))
|
||||
}
|
||||
|
||||
return 0
|
||||
|
||||
case <-c.SigUSR2Ch:
|
||||
buf := make([]byte, 32*1024*1024)
|
||||
n := runtime.Stack(buf[:], true)
|
||||
c.logger.Info("goroutine trace", "stack", string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *ServerCommand) adjustLogLevel(config *server.Config, logLevelWasNotSet bool) (string, error) {
|
||||
var logLevelString string
|
||||
if config.LogLevel != "" && logLevelWasNotSet {
|
||||
configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel))
|
||||
logLevelString = configLogLevel
|
||||
switch configLogLevel {
|
||||
case "trace":
|
||||
c.logger.SetLevel(log.Trace)
|
||||
case "debug":
|
||||
c.logger.SetLevel(log.Debug)
|
||||
case "notice", "info", "":
|
||||
c.logger.SetLevel(log.Info)
|
||||
case "warn", "warning":
|
||||
c.logger.SetLevel(log.Warn)
|
||||
case "err", "error":
|
||||
c.logger.SetLevel(log.Error)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown log level: %s", config.LogLevel)
|
||||
}
|
||||
}
|
||||
return logLevelString, nil
|
||||
}
|
||||
|
||||
func (c *ServerCommand) processLogLevelAndFormat(config *server.Config) (log.Level, string, bool, logging.LogFormat, error) {
|
||||
// Create a logger. We wrap it in a gated writer so that it doesn't
|
||||
// start logging too early.
|
||||
c.logGate = &gatedwriter.Writer{Writer: os.Stderr}
|
||||
c.logWriter = c.logGate
|
||||
if c.flagCombineLogs {
|
||||
c.logWriter = os.Stdout
|
||||
}
|
||||
var level log.Level
|
||||
var logLevelWasNotSet bool
|
||||
logFormat := logging.UnspecifiedFormat
|
||||
logLevelString := c.flagLogLevel
|
||||
c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
|
||||
switch c.flagLogLevel {
|
||||
case notSetValue, "":
|
||||
logLevelWasNotSet = true
|
||||
logLevelString = "info"
|
||||
level = log.Info
|
||||
case "trace":
|
||||
level = log.Trace
|
||||
case "debug":
|
||||
level = log.Debug
|
||||
case "notice", "info":
|
||||
level = log.Info
|
||||
case "warn", "warning":
|
||||
level = log.Warn
|
||||
case "err", "error":
|
||||
level = log.Error
|
||||
default:
|
||||
return level, logLevelString, logLevelWasNotSet, logFormat, fmt.Errorf("unknown log level: %s", c.flagLogLevel)
|
||||
}
|
||||
|
||||
if c.flagLogFormat != notSetValue {
|
||||
var err error
|
||||
logFormat, err = logging.ParseLogFormat(c.flagLogFormat)
|
||||
if err != nil {
|
||||
return level, logLevelString, logLevelWasNotSet, logFormat, err
|
||||
}
|
||||
}
|
||||
if logFormat == logging.UnspecifiedFormat {
|
||||
logFormat = logging.ParseEnvLogFormat()
|
||||
}
|
||||
if logFormat == logging.UnspecifiedFormat {
|
||||
var err error
|
||||
logFormat, err = logging.ParseLogFormat(config.LogFormat)
|
||||
if err != nil {
|
||||
return level, logLevelString, logLevelWasNotSet, logFormat, err
|
||||
}
|
||||
}
|
||||
|
||||
return level, logLevelString, logLevelWasNotSet, logFormat, nil
|
||||
}
|
||||
|
||||
func (c *ServerCommand) Run(args []string) int {
|
||||
f := c.Flags()
|
||||
|
||||
|
@ -373,6 +760,10 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
if c.flagRecovery {
|
||||
return c.runRecoveryMode()
|
||||
}
|
||||
|
||||
// Automatically enable dev mode if other dev flags are provided.
|
||||
if c.flagDevHA || c.flagDevTransactional || c.flagDevLeasedKV || c.flagDevThreeNode || c.flagDevFourCluster || c.flagDevAutoSeal || c.flagDevKVV1 {
|
||||
c.flagDev = true
|
||||
|
@ -413,18 +804,16 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
config.Listeners[0].Config["address"] = c.flagDevListenAddr
|
||||
}
|
||||
}
|
||||
for _, path := range c.flagConfigs {
|
||||
current, err := server.LoadConfig(path)
|
||||
|
||||
parsedConfig, err := c.parseConfig()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", path, err))
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = current
|
||||
config = parsedConfig
|
||||
} else {
|
||||
config = config.Merge(current)
|
||||
}
|
||||
config = config.Merge(parsedConfig)
|
||||
}
|
||||
|
||||
// Ensure at least one config was found.
|
||||
|
@ -437,57 +826,11 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// Create a logger. We wrap it in a gated writer so that it doesn't
|
||||
// start logging too early.
|
||||
c.logGate = &gatedwriter.Writer{Writer: os.Stderr}
|
||||
c.logWriter = c.logGate
|
||||
if c.flagCombineLogs {
|
||||
c.logWriter = os.Stdout
|
||||
}
|
||||
var level log.Level
|
||||
var logLevelWasNotSet bool
|
||||
logLevelString := c.flagLogLevel
|
||||
c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
|
||||
switch c.flagLogLevel {
|
||||
case notSetValue, "":
|
||||
logLevelWasNotSet = true
|
||||
logLevelString = "info"
|
||||
level = log.Info
|
||||
case "trace":
|
||||
level = log.Trace
|
||||
case "debug":
|
||||
level = log.Debug
|
||||
case "notice", "info":
|
||||
level = log.Info
|
||||
case "warn", "warning":
|
||||
level = log.Warn
|
||||
case "err", "error":
|
||||
level = log.Error
|
||||
default:
|
||||
c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel))
|
||||
return 1
|
||||
}
|
||||
|
||||
logFormat := logging.UnspecifiedFormat
|
||||
if c.flagLogFormat != notSetValue {
|
||||
var err error
|
||||
logFormat, err = logging.ParseLogFormat(c.flagLogFormat)
|
||||
level, logLevelString, logLevelWasNotSet, logFormat, err := c.processLogLevelAndFormat(config)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
}
|
||||
if logFormat == logging.UnspecifiedFormat {
|
||||
logFormat = logging.ParseEnvLogFormat()
|
||||
}
|
||||
if logFormat == logging.UnspecifiedFormat {
|
||||
var err error
|
||||
logFormat, err = logging.ParseLogFormat(config.LogFormat)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if c.flagDevThreeNode || c.flagDevFourCluster {
|
||||
c.logger = log.New(&log.LoggerOptions{
|
||||
|
@ -507,25 +850,13 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
allLoggers := []log.Logger{c.logger}
|
||||
|
||||
// adjust log level based on config setting
|
||||
if config.LogLevel != "" && logLevelWasNotSet {
|
||||
configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel))
|
||||
logLevelString = configLogLevel
|
||||
switch configLogLevel {
|
||||
case "trace":
|
||||
c.logger.SetLevel(log.Trace)
|
||||
case "debug":
|
||||
c.logger.SetLevel(log.Debug)
|
||||
case "notice", "info", "":
|
||||
c.logger.SetLevel(log.Info)
|
||||
case "warn", "warning":
|
||||
c.logger.SetLevel(log.Warn)
|
||||
case "err", "error":
|
||||
c.logger.SetLevel(log.Error)
|
||||
default:
|
||||
c.UI.Error(fmt.Sprintf("Unknown log level: %s", config.LogLevel))
|
||||
logLevelStr, err := c.adjustLogLevel(config, logLevelWasNotSet)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
if logLevelStr != "" {
|
||||
logLevelString = logLevelStr
|
||||
}
|
||||
|
||||
// create GRPC logger
|
||||
|
@ -580,7 +911,6 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
|
||||
if config.Storage.Type == "raft" {
|
||||
if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" {
|
||||
config.ClusterAddr = envCA
|
||||
|
@ -1066,6 +1396,9 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
info["cgo"] = "enabled"
|
||||
}
|
||||
|
||||
infoKeys = append(infoKeys, "recovery mode")
|
||||
info["recovery mode"] = "false"
|
||||
|
||||
// Server configuration output
|
||||
padding := 24
|
||||
sort.Strings(infoKeys)
|
||||
|
@ -1263,6 +1596,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
MaxRequestDuration: ln.maxRequestDuration,
|
||||
DisablePrintableCheck: config.DisablePrintableCheck,
|
||||
UnauthenticatedMetricsAccess: ln.unauthenticatedMetricsAccess,
|
||||
RecoveryMode: c.flagRecovery,
|
||||
})
|
||||
|
||||
// We perform validation on the config earlier, we can just cast here
|
||||
|
|
|
@ -19,16 +19,26 @@ import (
|
|||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
type GenerateRootKind int
|
||||
|
||||
const (
|
||||
GenerateRootRegular GenerateRootKind = iota
|
||||
GenerateRootDR
|
||||
GenerateRecovery
|
||||
)
|
||||
|
||||
// Generates a root token on the target cluster.
|
||||
func GenerateRoot(t testing.T, cluster *vault.TestCluster, drToken bool) string {
|
||||
token, err := GenerateRootWithError(t, cluster, drToken)
|
||||
func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string {
|
||||
t.Helper()
|
||||
token, err := GenerateRootWithError(t, cluster, kind)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool) (string, error) {
|
||||
func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) (string, error) {
|
||||
t.Helper()
|
||||
// If recovery keys supported, use those to perform root token generation instead
|
||||
var keys [][]byte
|
||||
if cluster.Cores[0].SealAccess().RecoveryKeySupported() {
|
||||
|
@ -36,13 +46,18 @@ func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool
|
|||
} else {
|
||||
keys = cluster.BarrierKeys
|
||||
}
|
||||
|
||||
client := cluster.Cores[0].Client
|
||||
f := client.Sys().GenerateRootInit
|
||||
if drToken {
|
||||
f = client.Sys().GenerateDROperationTokenInit
|
||||
|
||||
var err error
|
||||
var status *api.GenerateRootStatusResponse
|
||||
switch kind {
|
||||
case GenerateRootRegular:
|
||||
status, err = client.Sys().GenerateRootInit("", "")
|
||||
case GenerateRootDR:
|
||||
status, err = client.Sys().GenerateDROperationTokenInit("", "")
|
||||
case GenerateRecovery:
|
||||
status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
|
||||
}
|
||||
status, err := f("", "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -57,11 +72,16 @@ func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool
|
|||
if i >= status.Required {
|
||||
break
|
||||
}
|
||||
f := client.Sys().GenerateRootUpdate
|
||||
if drToken {
|
||||
f = client.Sys().GenerateDROperationTokenUpdate
|
||||
|
||||
strKey := base64.StdEncoding.EncodeToString(key)
|
||||
switch kind {
|
||||
case GenerateRootRegular:
|
||||
status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
|
||||
case GenerateRootDR:
|
||||
status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
|
||||
case GenerateRecovery:
|
||||
status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
|
||||
}
|
||||
status, err = f(base64.StdEncoding.EncodeToString(key), status.Nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -16,11 +17,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
assetfs "github.com/elazarl/go-bindata-assetfs"
|
||||
"github.com/hashicorp/errwrap"
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
sockaddr "github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
|
@ -111,9 +111,15 @@ func Handler(props *vault.HandlerProperties) http.Handler {
|
|||
// Create the muxer to handle the actual endpoints
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Handle non-forwarded paths
|
||||
mux.Handle("/v1/sys/config/state/", handleLogicalNoForward(core))
|
||||
mux.Handle("/v1/sys/host-info", handleLogicalNoForward(core))
|
||||
switch {
|
||||
case props.RecoveryMode:
|
||||
raw := vault.NewRawBackend(core)
|
||||
strategy := vault.GenerateRecoveryTokenStrategy(props.RecoveryToken)
|
||||
mux.Handle("/v1/sys/raw/", handleLogicalRecovery(raw, props.RecoveryToken))
|
||||
mux.Handle("/v1/sys/generate-recovery-token/attempt", handleSysGenerateRootAttempt(core, strategy))
|
||||
mux.Handle("/v1/sys/generate-recovery-token/update", handleSysGenerateRootUpdate(core, strategy))
|
||||
default:
|
||||
// Handle pprof paths
|
||||
mux.Handle("/v1/sys/pprof/", handleLogicalNoForward(core))
|
||||
|
||||
mux.Handle("/v1/sys/init", handleSysInit(core))
|
||||
|
@ -146,6 +152,7 @@ func Handler(props *vault.HandlerProperties) http.Handler {
|
|||
}
|
||||
mux.Handle("/ui", handleUIRedirect())
|
||||
mux.Handle("/", handleUIRedirect())
|
||||
|
||||
}
|
||||
|
||||
// Register metrics path without authentication if enabled
|
||||
|
@ -154,6 +161,7 @@ func Handler(props *vault.HandlerProperties) http.Handler {
|
|||
}
|
||||
|
||||
additionalRoutes(mux, core)
|
||||
}
|
||||
|
||||
// Wrap the handler in another handler to trigger all help paths.
|
||||
helpWrappedHandler := wrapHelpHandler(mux, core)
|
||||
|
@ -489,7 +497,7 @@ func parseQuery(values url.Values) map[string]interface{} {
|
|||
return nil
|
||||
}
|
||||
|
||||
func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
|
||||
func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
|
||||
// Limit the maximum number of bytes to MaxRequestSize to protect
|
||||
// against an indefinite amount of data being read.
|
||||
reader := r.Body
|
||||
|
@ -505,7 +513,7 @@ func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out
|
|||
}
|
||||
}
|
||||
var origBody io.ReadWriter
|
||||
if core.PerfStandby() {
|
||||
if perfStandby {
|
||||
// Since we're checking PerfStandby here we key on origBody being nil
|
||||
// or not later, so we need to always allocate so it's non-nil
|
||||
origBody = new(bytes.Buffer)
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"go.uber.org/atomic"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -18,7 +20,7 @@ import (
|
|||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
|
||||
func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
|
||||
ns, err := namespace.FromContext(r.Context())
|
||||
if err != nil {
|
||||
return nil, nil, http.StatusBadRequest, nil
|
||||
|
@ -78,7 +80,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
|||
passHTTPReq = true
|
||||
origBody = r.Body
|
||||
} else {
|
||||
origBody, err = parseRequest(core, r, w, &data)
|
||||
origBody, err = parseRequest(perfStandby, r, w, &data)
|
||||
if err == io.EOF {
|
||||
data = nil
|
||||
err = nil
|
||||
|
@ -105,14 +107,32 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
|||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
|
||||
}
|
||||
|
||||
req, err := requestAuth(core, r, &logical.Request{
|
||||
req := &logical.Request{
|
||||
ID: request_id,
|
||||
Operation: op,
|
||||
Path: path,
|
||||
Data: data,
|
||||
Connection: getConnection(r),
|
||||
Headers: r.Header,
|
||||
})
|
||||
}
|
||||
|
||||
if passHTTPReq {
|
||||
req.HTTPRequest = r
|
||||
}
|
||||
if responseWriter != nil {
|
||||
req.ResponseWriter = logical.NewHTTPResponseWriter(responseWriter)
|
||||
}
|
||||
|
||||
return req, origBody, 0, nil
|
||||
}
|
||||
|
||||
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
|
||||
req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
|
||||
if err != nil {
|
||||
return nil, nil, status, err
|
||||
}
|
||||
|
||||
req, err = requestAuth(core, r, req)
|
||||
if err != nil {
|
||||
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
|
||||
return nil, nil, http.StatusForbidden, nil
|
||||
|
@ -135,12 +155,6 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
|||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
|
||||
}
|
||||
|
||||
if passHTTPReq {
|
||||
req.HTTPRequest = r
|
||||
}
|
||||
if responseWriter != nil {
|
||||
req.ResponseWriter = logical.NewHTTPResponseWriter(responseWriter)
|
||||
}
|
||||
return req, origBody, 0, nil
|
||||
}
|
||||
|
||||
|
@ -168,6 +182,32 @@ func handleLogicalNoForward(core *vault.Core) http.Handler {
|
|||
return handleLogicalInternal(core, false, true)
|
||||
}
|
||||
|
||||
func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
req, _, statusCode, err := buildLogicalRequestNoAuth(false, w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
}
|
||||
reqToken := r.Header.Get(consts.AuthHeaderName)
|
||||
if reqToken == "" || token.Load() == "" || reqToken != token.Load() {
|
||||
respondError(w, http.StatusForbidden, nil)
|
||||
}
|
||||
|
||||
resp, err := raw.HandleRequest(r.Context(), req)
|
||||
if respondErrorCommon(w, req, resp, err) {
|
||||
return
|
||||
}
|
||||
|
||||
var httpResp *logical.HTTPResponse
|
||||
if resp != nil {
|
||||
httpResp = logical.LogicalResponseToHTTPResponse(resp)
|
||||
httpResp.RequestID = req.ID
|
||||
}
|
||||
respondOk(w, httpResp)
|
||||
})
|
||||
}
|
||||
|
||||
// handleLogicalInternal is a common helper that returns a handler for
|
||||
// processing logical requests. The behavior depends on the various boolean
|
||||
// toggles. Refer to usage on functions for possible behaviors.
|
||||
|
|
|
@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r
|
|||
func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) {
|
||||
// Parse the request
|
||||
var req GenerateRootInitRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req GenerateRootUpdateRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request)
|
|||
|
||||
// Parse the request
|
||||
var req InitRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler {
|
|||
func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req JoinRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool,
|
|||
func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req RekeyRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler {
|
|||
|
||||
// Parse the request
|
||||
var req RekeyUpdateRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery
|
|||
func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req RekeyVerificationUpdateRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
|
|||
|
||||
// Parse the request
|
||||
var req UnsealRequest
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/armon/go-metrics"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -12,12 +13,11 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-raftchunking"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/raft"
|
||||
snapshot "github.com/hashicorp/raft-snapshot"
|
||||
raftboltdb "github.com/hashicorp/vault/physical/raft/logstore"
|
||||
|
@ -367,6 +367,26 @@ type SetupOpts struct {
|
|||
// StartAsLeader is used to specify this node should start as leader and
|
||||
// bypass the leader election. This should be used with caution.
|
||||
StartAsLeader bool
|
||||
|
||||
// RecoveryModeConfig is the configuration for the raft cluster in recovery
|
||||
// mode.
|
||||
RecoveryModeConfig *raft.Configuration
|
||||
}
|
||||
|
||||
func (b *RaftBackend) StartRecoveryCluster(ctx context.Context, peer Peer) error {
|
||||
recoveryModeConfig := &raft.Configuration{
|
||||
Servers: []raft.Server{
|
||||
{
|
||||
ID: raft.ServerID(peer.ID),
|
||||
Address: raft.ServerAddress(peer.Address),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return b.SetupCluster(context.Background(), SetupOpts{
|
||||
StartAsLeader: true,
|
||||
RecoveryModeConfig: recoveryModeConfig,
|
||||
})
|
||||
}
|
||||
|
||||
// SetupCluster starts the raft cluster and enables the networking needed for
|
||||
|
@ -477,6 +497,13 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error {
|
|||
b.logger.Info("raft recovery deleted peers.json")
|
||||
}
|
||||
|
||||
if opts.RecoveryModeConfig != nil {
|
||||
err = raft.RecoverCluster(raftConfig, b.fsm, b.logStore, b.stableStore, b.snapStore, b.raftTransport, *opts.RecoveryModeConfig)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("recovering raft cluster failed: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
raftObj, err := raft.NewRaft(raftConfig, b.fsm.chunker, b.logStore, b.stableStore, b.snapStore, b.raftTransport)
|
||||
b.fsm.SetNoopRestore(false)
|
||||
if err != nil {
|
||||
|
|
|
@ -423,6 +423,10 @@ type Core struct {
|
|||
// Stores any funcs that should be run on successful postUnseal
|
||||
postUnsealFuncs []func()
|
||||
|
||||
// Stores any funcs that should be run on successful barrier unseal in
|
||||
// recovery mode
|
||||
postRecoveryUnsealFuncs []func() error
|
||||
|
||||
// replicationFailure is used to mark when replication has entered an
|
||||
// unrecoverable failure.
|
||||
replicationFailure *uint32
|
||||
|
@ -465,6 +469,8 @@ type Core struct {
|
|||
rawConfig *server.Config
|
||||
|
||||
coreNumber int
|
||||
|
||||
recoveryMode bool
|
||||
}
|
||||
|
||||
// CoreConfig is used to parameterize a core
|
||||
|
@ -542,6 +548,8 @@ type CoreConfig struct {
|
|||
MetricsHelper *metricsutil.MetricsHelper
|
||||
|
||||
CounterSyncInterval time.Duration
|
||||
|
||||
RecoveryMode bool
|
||||
}
|
||||
|
||||
func (c *CoreConfig) Clone() *CoreConfig {
|
||||
|
@ -668,6 +676,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
requests: new(uint64),
|
||||
syncInterval: syncInterval,
|
||||
},
|
||||
recoveryMode: conf.RecoveryMode,
|
||||
}
|
||||
|
||||
atomic.StoreUint32(c.sealed, 1)
|
||||
|
@ -726,25 +735,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
|
||||
var err error
|
||||
|
||||
if conf.PluginDirectory != "" {
|
||||
c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("core setup failed, could not verify plugin directory: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Construct a new AES-GCM barrier
|
||||
c.barrier, err = NewAESGCMBarrier(c.physical)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("barrier setup failed: {{err}}", err)
|
||||
}
|
||||
|
||||
createSecondaries(c, conf)
|
||||
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||
c.ha = conf.HAPhysical
|
||||
}
|
||||
|
||||
// We create the funcs here, then populate the given config with it so that
|
||||
// the caller can share state
|
||||
conf.ReloadFuncsLock = &c.reloadFuncsLock
|
||||
|
@ -753,6 +749,25 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
c.reloadFuncsLock.Unlock()
|
||||
conf.ReloadFuncs = &c.reloadFuncs
|
||||
|
||||
// All the things happening below this are not required in
|
||||
// recovery mode
|
||||
if c.recoveryMode {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
if conf.PluginDirectory != "" {
|
||||
c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("core setup failed, could not verify plugin directory: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
createSecondaries(c, conf)
|
||||
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||
c.ha = conf.HAPhysical
|
||||
}
|
||||
|
||||
logicalBackends := make(map[string]logical.Factory)
|
||||
for k, f := range conf.LogicalBackends {
|
||||
logicalBackends[k] = f
|
||||
|
|
|
@ -218,7 +218,7 @@ func testSysRekey_Verification(t *testing.T, recovery bool) {
|
|||
} else {
|
||||
// We haven't finished, so generating a root token should still be the
|
||||
// old keys (which are still currently set)
|
||||
testhelpers.GenerateRoot(t, cluster, false)
|
||||
testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular)
|
||||
}
|
||||
|
||||
// Provide the final new key
|
||||
|
@ -256,7 +256,7 @@ func testSysRekey_Verification(t *testing.T, recovery bool) {
|
|||
}
|
||||
} else {
|
||||
// The old keys should no longer work
|
||||
_, err := testhelpers.GenerateRootWithError(t, cluster, false)
|
||||
_, err := testhelpers.GenerateRootWithError(t, cluster, testhelpers.GenerateRootRegular)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
@ -273,6 +273,6 @@ func testSysRekey_Verification(t *testing.T, recovery bool) {
|
|||
if err := client.Sys().GenerateRootCancel(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testhelpers.GenerateRoot(t, cluster, false)
|
||||
testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package token
|
||||
package misc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package token
|
||||
package misc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
|
140
vault/external_tests/misc/recovery_test.go
Normal file
140
vault/external_tests/misc/recovery_test.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package misc
|
||||
|
||||
import (
|
||||
"github.com/go-test/deep"
|
||||
"go.uber.org/atomic"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/testhelpers"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical/inmem"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(hclog.Debug).Named(t.Name())
|
||||
inm, err := inmem.NewInmemHA(nil, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var keys [][]byte
|
||||
var secretUUID string
|
||||
var rootToken string
|
||||
{
|
||||
conf := vault.CoreConfig{
|
||||
Physical: inm,
|
||||
Logger: logger,
|
||||
}
|
||||
opts := vault.TestClusterOptions{
|
||||
HandlerFunc: http.Handler,
|
||||
NumCores: 1,
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, &conf, &opts)
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
client := cluster.Cores[0].Client
|
||||
rootToken = client.Token()
|
||||
var fooVal = map[string]interface{}{"bar": 1.0}
|
||||
_, err = client.Logical().Write("secret/foo", fooVal)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secret, err := client.Logical().List("secret/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(secret.Data["keys"], []interface{}{"foo"}); len(diff) > 0 {
|
||||
t.Fatalf("got=%v, want=%v, diff: %v", secret.Data["keys"], []string{"foo"}, diff)
|
||||
}
|
||||
mounts, err := cluster.Cores[0].Client.Sys().ListMounts()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secretMount := mounts["secret/"]
|
||||
if secretMount == nil {
|
||||
t.Fatalf("secret mount not found, mounts: %v", mounts)
|
||||
}
|
||||
secretUUID = secretMount.UUID
|
||||
cluster.EnsureCoresSealed(t)
|
||||
keys = cluster.BarrierKeys
|
||||
}
|
||||
|
||||
{
|
||||
// Now bring it up in recovery mode.
|
||||
var tokenRef atomic.String
|
||||
conf := vault.CoreConfig{
|
||||
Physical: inm,
|
||||
Logger: logger,
|
||||
RecoveryMode: true,
|
||||
}
|
||||
opts := vault.TestClusterOptions{
|
||||
HandlerFunc: http.Handler,
|
||||
NumCores: 1,
|
||||
SkipInit: true,
|
||||
DefaultHandlerProperties: vault.HandlerProperties{
|
||||
RecoveryMode: true,
|
||||
RecoveryToken: &tokenRef,
|
||||
},
|
||||
}
|
||||
cluster := vault.NewTestCluster(t, &conf, &opts)
|
||||
cluster.BarrierKeys = keys
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
client := cluster.Cores[0].Client
|
||||
recoveryToken := testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRecovery)
|
||||
_, err = testhelpers.GenerateRootWithError(t, cluster, testhelpers.GenerateRecovery)
|
||||
if err == nil {
|
||||
t.Fatal("expected second generate-root to fail")
|
||||
}
|
||||
client.SetToken(recoveryToken)
|
||||
|
||||
secret, err := client.Logical().List(path.Join("sys/raw/logical", secretUUID))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(secret.Data["keys"], []interface{}{"foo"}); len(diff) > 0 {
|
||||
t.Fatalf("got=%v, want=%v, diff: %v", secret.Data, []string{"foo"}, diff)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Delete(path.Join("sys/raw/logical", secretUUID, "foo"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cluster.EnsureCoresSealed(t)
|
||||
}
|
||||
|
||||
{
|
||||
// Now go back to regular mode and verify that our changes are present
|
||||
conf := vault.CoreConfig{
|
||||
Physical: inm,
|
||||
Logger: logger,
|
||||
}
|
||||
opts := vault.TestClusterOptions{
|
||||
HandlerFunc: http.Handler,
|
||||
NumCores: 1,
|
||||
SkipInit: true,
|
||||
}
|
||||
cluster := vault.NewTestCluster(t, &conf, &opts)
|
||||
cluster.BarrierKeys = keys
|
||||
cluster.Start()
|
||||
testhelpers.EnsureCoresUnsealed(t, cluster)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
client := cluster.Cores[0].Client
|
||||
client.SetToken(rootToken)
|
||||
secret, err := client.Logical().List("secret/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret != nil {
|
||||
t.Fatal("expected no data in secret mount")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,10 +4,10 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/pgpkeys"
|
||||
"github.com/hashicorp/vault/helper/xor"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
|
@ -77,10 +77,10 @@ type GenerateRootResult struct {
|
|||
func (c *Core) GenerateRootProgress() (int, error) {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.Sealed() {
|
||||
if c.Sealed() && !c.recoveryMode {
|
||||
return 0, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
if c.standby && !c.recoveryMode {
|
||||
return 0, consts.ErrStandby
|
||||
}
|
||||
|
||||
|
@ -95,10 +95,10 @@ func (c *Core) GenerateRootProgress() (int, error) {
|
|||
func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.Sealed() {
|
||||
if c.Sealed() && !c.recoveryMode {
|
||||
return nil, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
if c.standby && !c.recoveryMode {
|
||||
return nil, consts.ErrStandby
|
||||
}
|
||||
|
||||
|
@ -141,10 +141,17 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg
|
|||
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.Sealed() {
|
||||
if c.Sealed() && !c.recoveryMode {
|
||||
return consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
barrierSealed, err := c.barrier.Sealed()
|
||||
if err != nil {
|
||||
return errors.New("unable to check barrier seal status")
|
||||
}
|
||||
if !barrierSealed && c.recoveryMode {
|
||||
return errors.New("attempt to generate recovery operation token when already unsealed")
|
||||
}
|
||||
if c.standby && !c.recoveryMode {
|
||||
return consts.ErrStandby
|
||||
}
|
||||
|
||||
|
@ -174,6 +181,8 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg
|
|||
switch strategy.(type) {
|
||||
case generateStandardRootToken:
|
||||
c.logger.Info("root generation initialized", "nonce", c.generateRootConfig.Nonce)
|
||||
case *generateRecoveryToken:
|
||||
c.logger.Info("recovery operation token generation initialized", "nonce", c.generateRootConfig.Nonce)
|
||||
default:
|
||||
c.logger.Info("dr operation token generation initialized", "nonce", c.generateRootConfig.Nonce)
|
||||
}
|
||||
|
@ -217,10 +226,19 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
|
|||
// Ensure we are already unsealed
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.Sealed() {
|
||||
if c.Sealed() && !c.recoveryMode {
|
||||
return nil, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
|
||||
barrierSealed, err := c.barrier.Sealed()
|
||||
if err != nil {
|
||||
return nil, errors.New("unable to check barrier seal status")
|
||||
}
|
||||
if !barrierSealed && c.recoveryMode {
|
||||
return nil, errors.New("attempt to generate recovery operation token when already unsealed")
|
||||
}
|
||||
|
||||
if c.standby && !c.recoveryMode {
|
||||
return nil, consts.ErrStandby
|
||||
}
|
||||
|
||||
|
@ -263,31 +281,82 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Recover the master key
|
||||
var masterKey []byte
|
||||
// Combine the key parts
|
||||
var combinedKey []byte
|
||||
if config.SecretThreshold == 1 {
|
||||
masterKey = c.generateRootProgress[0]
|
||||
combinedKey = c.generateRootProgress[0]
|
||||
c.generateRootProgress = nil
|
||||
} else {
|
||||
masterKey, err = shamir.Combine(c.generateRootProgress)
|
||||
combinedKey, err = shamir.Combine(c.generateRootProgress)
|
||||
c.generateRootProgress = nil
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to compute master key: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the master key
|
||||
if c.seal.RecoveryKeySupported() {
|
||||
if err := c.seal.VerifyRecoveryKey(ctx, masterKey); err != nil {
|
||||
switch {
|
||||
case c.seal.RecoveryKeySupported():
|
||||
// Ensure that the combined recovery key is valid
|
||||
if err := c.seal.VerifyRecoveryKey(ctx, combinedKey); err != nil {
|
||||
c.logger.Error("root generation aborted, recovery key verification failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := c.barrier.VerifyMaster(masterKey); err != nil {
|
||||
// If we are in recovery mode, then retrieve
|
||||
// the stored keys and unseal the barrier
|
||||
if c.recoveryMode {
|
||||
if !c.seal.StoredKeysSupported() {
|
||||
c.logger.Error("root generation aborted, recovery key verified but stored keys unsupported")
|
||||
return nil, errors.New("recovery key verified but stored keys unsupported")
|
||||
}
|
||||
masterKeyShares, err := c.seal.GetStoredKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("unable to retrieve stored keys in recovery mode: {{err}}", err)
|
||||
}
|
||||
|
||||
switch len(masterKeyShares) {
|
||||
case 0:
|
||||
return nil, errors.New("seal returned no master key shares in recovery mode")
|
||||
case 1:
|
||||
combinedKey = masterKeyShares[0]
|
||||
default:
|
||||
combinedKey, err = shamir.Combine(masterKeyShares)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to compute master key in recovery mode: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Use the retrieved master key to unseal the barrier
|
||||
if err := c.barrier.Unseal(ctx, combinedKey); err != nil {
|
||||
c.logger.Error("root generation aborted, recovery operation token verification failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
default:
|
||||
switch {
|
||||
case c.recoveryMode:
|
||||
// If we are in recovery mode, being able to unseal
|
||||
// the barrier is how we establish authentication
|
||||
if err := c.barrier.Unseal(ctx, combinedKey); err != nil {
|
||||
c.logger.Error("root generation aborted, recovery operation token verification failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if err := c.barrier.VerifyMaster(combinedKey); err != nil {
|
||||
c.logger.Error("root generation aborted, master key verification failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Authentication in recovery mode is successful
|
||||
if c.recoveryMode {
|
||||
// Run any post unseal functions that are set
|
||||
for _, v := range c.postRecoveryUnsealFuncs {
|
||||
if err := v(); err != nil {
|
||||
return nil, errwrap.Wrapf("failed to run post unseal func: {{err}}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run the generate strategy
|
||||
token, cleanupFunc, err := strategy.generate(ctx, c)
|
||||
|
@ -334,14 +403,12 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
|
|||
|
||||
switch strategy.(type) {
|
||||
case generateStandardRootToken:
|
||||
if c.logger.IsInfo() {
|
||||
c.logger.Info("root generation finished", "nonce", c.generateRootConfig.Nonce)
|
||||
}
|
||||
case *generateRecoveryToken:
|
||||
c.logger.Info("recovery operation token generation finished", "nonce", c.generateRootConfig.Nonce)
|
||||
default:
|
||||
if c.logger.IsInfo() {
|
||||
c.logger.Info("dr operation token generation finished", "nonce", c.generateRootConfig.Nonce)
|
||||
}
|
||||
}
|
||||
|
||||
c.generateRootProgress = nil
|
||||
c.generateRootConfig = nil
|
||||
|
@ -352,10 +419,10 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
|
|||
func (c *Core) GenerateRootCancel() error {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.Sealed() {
|
||||
if c.Sealed() && !c.recoveryMode {
|
||||
return consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
if c.standby && !c.recoveryMode {
|
||||
return consts.ErrStandby
|
||||
}
|
||||
|
||||
|
|
31
vault/generate_root_recovery.go
Normal file
31
vault/generate_root_recovery.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/base62"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// GenerateRecoveryTokenStrategy is the strategy used to generate a
|
||||
// recovery token
|
||||
func GenerateRecoveryTokenStrategy(token *atomic.String) GenerateRootStrategy {
|
||||
return &generateRecoveryToken{token: token}
|
||||
}
|
||||
|
||||
// generateRecoveryToken implements the GenerateRootStrategy and is in
|
||||
// charge of creating recovery tokens.
|
||||
type generateRecoveryToken struct {
|
||||
token *atomic.String
|
||||
}
|
||||
|
||||
func (g *generateRecoveryToken) generate(ctx context.Context, c *Core) (string, func(), error) {
|
||||
id, err := base62.Random(TokenLength)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
token := "r." + id
|
||||
g.token.Store(token)
|
||||
|
||||
return token, func() { g.token.Store("") }, nil
|
||||
}
|
|
@ -38,6 +38,31 @@ var (
|
|||
initInProgress uint32
|
||||
)
|
||||
|
||||
func (c *Core) InitializeRecovery(ctx context.Context) error {
|
||||
if !c.recoveryMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedClusterAddr, err := url.Parse(c.ClusterAddr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.postRecoveryUnsealFuncs = append(c.postRecoveryUnsealFuncs, func() error {
|
||||
return raftStorage.StartRecoveryCluster(context.Background(), raft.Peer{
|
||||
ID: raftStorage.NodeID(),
|
||||
Address: parsedClusterAddr.Host,
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialized checks if the Vault is already initialized
|
||||
func (c *Core) Initialized(ctx context.Context) (bool, error) {
|
||||
// Check the barrier first
|
||||
|
|
216
vault/logical_raw.go
Normal file
216
vault/logical_raw.go
Normal file
|
@ -0,0 +1,216 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/compressutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// protectedPaths cannot be accessed via the raw APIs.
|
||||
// This is both for security and to prevent disrupting Vault.
|
||||
protectedPaths = []string{
|
||||
keyringPath,
|
||||
// Changing the cluster info path can change the cluster ID which can be disruptive
|
||||
coreLocalClusterInfoPath,
|
||||
}
|
||||
)
|
||||
|
||||
type RawBackend struct {
|
||||
*framework.Backend
|
||||
barrier SecurityBarrier
|
||||
logger log.Logger
|
||||
checkRaw func(path string) error
|
||||
recoveryMode bool
|
||||
}
|
||||
|
||||
func NewRawBackend(core *Core) *RawBackend {
|
||||
r := &RawBackend{
|
||||
barrier: core.barrier,
|
||||
logger: core.logger.Named("raw"),
|
||||
checkRaw: func(path string) error {
|
||||
return nil
|
||||
},
|
||||
recoveryMode: core.recoveryMode,
|
||||
}
|
||||
r.Backend = &framework.Backend{
|
||||
Paths: rawPaths("sys/", r),
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// handleRawRead is used to read directly from the barrier
|
||||
func (b *RawBackend) handleRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
|
||||
if b.recoveryMode {
|
||||
b.logger.Info("reading", "path", path)
|
||||
}
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot read '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Run additional checks if needed
|
||||
if err := b.checkRaw(path); err != nil {
|
||||
b.logger.Warn(err.Error(), "path", path)
|
||||
return logical.ErrorResponse("cannot read '%s'", path), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
entry, err := b.barrier.Get(ctx, path)
|
||||
if err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Run this through the decompression helper to see if it's been compressed.
|
||||
// If the input contained the compression canary, `outputBytes` will hold
|
||||
// the decompressed data. If the input was not compressed, then `outputBytes`
|
||||
// will be nil.
|
||||
outputBytes, _, err := compressutil.Decompress(entry.Value)
|
||||
if err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
|
||||
// `outputBytes` is nil if the input is uncompressed. In that case set it to the original input.
|
||||
if outputBytes == nil {
|
||||
outputBytes = entry.Value
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"value": string(outputBytes),
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// handleRawWrite is used to write directly to the barrier
|
||||
func (b *RawBackend) handleRawWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
|
||||
if b.recoveryMode {
|
||||
b.logger.Info("writing", "path", path)
|
||||
}
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot write '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
value := data.Get("value").(string)
|
||||
entry := &logical.StorageEntry{
|
||||
Key: path,
|
||||
Value: []byte(value),
|
||||
}
|
||||
if err := b.barrier.Put(ctx, entry); err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleRawDelete is used to delete directly from the barrier
|
||||
func (b *RawBackend) handleRawDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
|
||||
if b.recoveryMode {
|
||||
b.logger.Info("deleting", "path", path)
|
||||
}
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot delete '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
if err := b.barrier.Delete(ctx, path); err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleRawList is used to list directly from the barrier
|
||||
func (b *RawBackend) handleRawList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
if path != "" && !strings.HasSuffix(path, "/") {
|
||||
path = path + "/"
|
||||
}
|
||||
|
||||
if b.recoveryMode {
|
||||
b.logger.Info("listing", "path", path)
|
||||
}
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot list '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Run additional checks if needed
|
||||
if err := b.checkRaw(path); err != nil {
|
||||
b.logger.Warn(err.Error(), "path", path)
|
||||
return logical.ErrorResponse("cannot list '%s'", path), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
keys, err := b.barrier.List(ctx, path)
|
||||
if err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
return logical.ListResponse(keys), nil
|
||||
}
|
||||
|
||||
func rawPaths(prefix string, r *RawBackend) []*framework.Path {
|
||||
return []*framework.Path{
|
||||
&framework.Path{
|
||||
Pattern: prefix + "(raw/?$|raw/(?P<path>.+))",
|
||||
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"path": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
},
|
||||
"value": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
},
|
||||
},
|
||||
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
logical.ReadOperation: &framework.PathOperation{
|
||||
Callback: r.handleRawRead,
|
||||
Summary: "Read the value of the key at the given path.",
|
||||
},
|
||||
logical.UpdateOperation: &framework.PathOperation{
|
||||
Callback: r.handleRawWrite,
|
||||
Summary: "Update the value of the key at the given path.",
|
||||
},
|
||||
logical.DeleteOperation: &framework.PathOperation{
|
||||
Callback: r.handleRawDelete,
|
||||
Summary: "Delete the key with given path.",
|
||||
},
|
||||
logical.ListOperation: &framework.PathOperation{
|
||||
Callback: r.handleRawList,
|
||||
Summary: "Return a list keys for a given path prefix.",
|
||||
},
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["raw"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["raw"][1]),
|
||||
},
|
||||
}
|
||||
}
|
|
@ -23,14 +23,13 @@ import (
|
|||
|
||||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
memdb "github.com/hashicorp/go-memdb"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/hostutil"
|
||||
"github.com/hashicorp/vault/helper/identity"
|
||||
"github.com/hashicorp/vault/helper/metricsutil"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/compressutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||
|
@ -40,16 +39,6 @@ import (
|
|||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
var (
|
||||
// protectedPaths cannot be accessed via the raw APIs.
|
||||
// This is both for security and to prevent disrupting Vault.
|
||||
protectedPaths = []string{
|
||||
keyringPath,
|
||||
// Changing the cluster info path can change the cluster ID which can be disruptive
|
||||
coreLocalClusterInfoPath,
|
||||
}
|
||||
)
|
||||
|
||||
const maxBytes = 128 * 1024
|
||||
|
||||
func systemBackendMemDBSchema() *memdb.DBSchema {
|
||||
|
@ -172,40 +161,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend {
|
|||
b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath())
|
||||
|
||||
if core.rawEnabled {
|
||||
b.Backend.Paths = append(b.Backend.Paths, &framework.Path{
|
||||
Pattern: "(raw/?$|raw/(?P<path>.+))",
|
||||
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"path": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
},
|
||||
"value": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
},
|
||||
},
|
||||
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
logical.ReadOperation: &framework.PathOperation{
|
||||
Callback: b.handleRawRead,
|
||||
Summary: "Read the value of the key at the given path.",
|
||||
},
|
||||
logical.UpdateOperation: &framework.PathOperation{
|
||||
Callback: b.handleRawWrite,
|
||||
Summary: "Update the value of the key at the given path.",
|
||||
},
|
||||
logical.DeleteOperation: &framework.PathOperation{
|
||||
Callback: b.handleRawDelete,
|
||||
Summary: "Delete the key with given path.",
|
||||
},
|
||||
logical.ListOperation: &framework.PathOperation{
|
||||
Callback: b.handleRawList,
|
||||
Summary: "Return a list keys for a given path prefix.",
|
||||
},
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["raw"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["raw"][1]),
|
||||
})
|
||||
b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...)
|
||||
}
|
||||
|
||||
if _, ok := core.underlyingPhysical.(*raft.RaftBackend); ok {
|
||||
|
@ -216,6 +172,17 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend {
|
|||
return b
|
||||
}
|
||||
|
||||
func (b *SystemBackend) rawPaths() []*framework.Path {
|
||||
r := &RawBackend{
|
||||
barrier: b.Core.barrier,
|
||||
logger: b.logger,
|
||||
checkRaw: func(path string) error {
|
||||
return checkRaw(b, path)
|
||||
},
|
||||
}
|
||||
return rawPaths("", r)
|
||||
}
|
||||
|
||||
// SystemBackend implements logical.Backend and is used to interact with
|
||||
// the core of the system. This backend is hardcoded to exist at the "sys"
|
||||
// prefix. Conceptually it is similar to procfs on Linux.
|
||||
|
@ -2248,123 +2215,6 @@ func (b *SystemBackend) handleConfigUIHeadersDelete(ctx context.Context, req *lo
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// handleRawRead is used to read directly from the barrier
|
||||
func (b *SystemBackend) handleRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot read '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Run additional checks if needed
|
||||
if err := checkRaw(b, path); err != nil {
|
||||
b.Core.logger.Warn(err.Error(), "path", path)
|
||||
return logical.ErrorResponse("cannot read '%s'", path), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
entry, err := b.Core.barrier.Get(ctx, path)
|
||||
if err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Run this through the decompression helper to see if it's been compressed.
|
||||
// If the input contained the compression canary, `outputBytes` will hold
|
||||
// the decompressed data. If the input was not compressed, then `outputBytes`
|
||||
// will be nil.
|
||||
outputBytes, _, err := compressutil.Decompress(entry.Value)
|
||||
if err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
|
||||
// `outputBytes` is nil if the input is uncompressed. In that case set it to the original input.
|
||||
if outputBytes == nil {
|
||||
outputBytes = entry.Value
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"value": string(outputBytes),
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// handleRawWrite is used to write directly to the barrier
|
||||
func (b *SystemBackend) handleRawWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot write '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
value := data.Get("value").(string)
|
||||
entry := &logical.StorageEntry{
|
||||
Key: path,
|
||||
Value: []byte(value),
|
||||
}
|
||||
if err := b.Core.barrier.Put(ctx, entry); err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleRawDelete is used to delete directly from the barrier
|
||||
func (b *SystemBackend) handleRawDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot delete '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
if err := b.Core.barrier.Delete(ctx, path); err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleRawList is used to list directly from the barrier
|
||||
func (b *SystemBackend) handleRawList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
path := data.Get("path").(string)
|
||||
if path != "" && !strings.HasSuffix(path, "/") {
|
||||
path = path + "/"
|
||||
}
|
||||
|
||||
// Prevent access of protected paths
|
||||
for _, p := range protectedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
err := fmt.Sprintf("cannot list '%s'", path)
|
||||
return logical.ErrorResponse(err), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Run additional checks if needed
|
||||
if err := checkRaw(b, path); err != nil {
|
||||
b.Core.logger.Warn(err.Error(), "path", path)
|
||||
return logical.ErrorResponse("cannot list '%s'", path), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
keys, err := b.Core.barrier.List(ctx, path)
|
||||
if err != nil {
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
return logical.ListResponse(keys), nil
|
||||
}
|
||||
|
||||
// handleKeyStatus returns status information about the backend key
|
||||
func (b *SystemBackend) handleKeyStatus(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the key info
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/wrapping"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
uberAtomic "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -43,6 +44,8 @@ type HandlerProperties struct {
|
|||
MaxRequestSize int64
|
||||
MaxRequestDuration time.Duration
|
||||
DisablePrintableCheck bool
|
||||
RecoveryMode bool
|
||||
RecoveryToken *uberAtomic.String
|
||||
UnauthenticatedMetricsAccess bool
|
||||
}
|
||||
|
||||
|
|
|
@ -1042,6 +1042,7 @@ type TestClusterOptions struct {
|
|||
KeepStandbysSealed bool
|
||||
SkipInit bool
|
||||
HandlerFunc func(*HandlerProperties) http.Handler
|
||||
DefaultHandlerProperties HandlerProperties
|
||||
BaseListenAddress string
|
||||
NumCores int
|
||||
SealFunc func() Seal
|
||||
|
@ -1417,7 +1418,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
|
||||
coreConfig.DevToken = base.DevToken
|
||||
coreConfig.CounterSyncInterval = base.CounterSyncInterval
|
||||
|
||||
coreConfig.RecoveryMode = base.RecoveryMode
|
||||
}
|
||||
|
||||
if coreConfig.RawConfig == nil {
|
||||
|
@ -1511,10 +1512,12 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
cores = append(cores, c)
|
||||
coreConfigs = append(coreConfigs, &localConfig)
|
||||
if opts != nil && opts.HandlerFunc != nil {
|
||||
handlers[i] = opts.HandlerFunc(&HandlerProperties{
|
||||
Core: c,
|
||||
MaxRequestDuration: DefaultMaxRequestDuration,
|
||||
})
|
||||
props := opts.DefaultHandlerProperties
|
||||
props.Core = c
|
||||
if props.MaxRequestDuration == 0 {
|
||||
props.MaxRequestDuration = DefaultMaxRequestDuration
|
||||
}
|
||||
handlers[i] = opts.HandlerFunc(&props)
|
||||
servers[i].Handler = handlers[i]
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue