Recovery Mode (#7559)

* Initial work

* rework

* s/dr/recovery

* Add sys/raw support to recovery mode (#7577)

* Factor the raw paths out so they can be run with a SystemBackend.

# Conflicts:
#	vault/logical_system.go

* Add handleLogicalRecovery which is like handleLogical but is only
sufficient for use with the sys-raw endpoint in recovery mode.  No
authentication is done yet.

* Integrate with recovery-mode.  We now handle unauthenticated sys/raw
requests, albeit on path v1/raw instead of v1/sys/raw.

* Use sys/raw instead of raw during recovery.

* Don't bother persisting the recovery token.  Authenticate sys/raw
requests with it.
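
A sketch of the client-side flow the bullets above enable, assuming a server already running with -recovery against an initialized storage backend. The GenerateRecoveryOperationToken* client methods, the sys/raw handling, and the "r." token prefix are all added in this commit; the key shares and the decoded token value below are placeholders:

package main

import (
	"encoding/base64"
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig()) // honors VAULT_ADDR
	if err != nil {
		log.Fatal(err)
	}

	// Begin recovery-token generation; with no OTP or PGP key supplied,
	// the server generates an OTP used to encode the resulting token.
	status, err := client.Sys().GenerateRecoveryOperationTokenInit("", "")
	if err != nil {
		log.Fatal(err)
	}

	// Placeholder: the operator's unseal (or recovery) key shares.
	var keyShares [][]byte
	for _, share := range keyShares {
		status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(
			base64.StdEncoding.EncodeToString(share), status.Nonce)
		if err != nil {
			log.Fatal(err)
		}
		if status.Complete {
			break
		}
	}

	// status.EncodedToken must be decoded with the OTP (for example via
	// `vault operator generate-root -recovery-token -decode=...`); the
	// plaintext token is "r."-prefixed and is never persisted server-side.
	client.SetToken("r.PLACEHOLDER")

	// With the recovery token set, sys/raw can inspect raw storage.
	secret, err := client.Logical().List("sys/raw/sys")
	if err != nil {
		log.Fatal(err)
	}
	if secret != nil {
		fmt.Println(secret.Data["keys"])
	}
}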

* RecoveryMode: Support generate-root for autounseals (#7591)

* Recovery: Abstract config creation and log settings

* Recovery mode integration test. (#7600)

* Recovery: Touch up (#7607)

* Recovery: Touch up

* revert the raw backend creation changes

* Added recovery operation token prefix

* Move RawBackend to its own file

* Update API path and hit it using CLI flag on generate-root

* Fix a panic triggered when handling a request that yields a nil response. (#7618)

* Improve integ test to actually make changes while in recovery mode and
verify they're still there after coming back in regular mode.

* Refuse to allow a second recovery token to be generated.

* Resize raft cluster to size 1 and start as leader (#7626)

* RecoveryMode: Setup raft cluster post unseal (#7635)

* Setup raft cluster post unseal in recovery mode

* Remove marking as unsealed as it's not needed

* Address review comments

* Accept only one seal config in recovery mode as there is no scope for migration
Vishal Nayak, 2019-10-15 00:55:31 -04:00 (committed by GitHub)
commit 0d077d7945 (parent ffb699e48c)
24 changed files with 1264 additions and 413 deletions
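
For the raft bullets above: once the operator authenticates, the node rewrites its peer set to a single-server configuration and starts as leader, which is what StartRecoveryCluster in physical/raft/raft.go below does. A minimal sketch of that configuration's shape, with hypothetical ID and address values:

package main

import (
	"fmt"

	"github.com/hashicorp/raft"
)

func main() {
	// Single-node peer set used to recover the cluster in place.
	cfg := raft.Configuration{
		Servers: []raft.Server{{
			ID:      raft.ServerID("node-1"),              // this node's raft ID
			Address: raft.ServerAddress("10.0.0.1:8201"), // cluster address host
		}},
	}
	fmt.Printf("%+v\n", cfg)
}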


@@ -10,6 +10,10 @@ func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, err
return c.generateRootStatusCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
}
func (c *Sys) GenerateRecoveryOperationTokenStatus() (*GenerateRootStatusResponse, error) {
return c.generateRootStatusCommon("/v1/sys/generate-recovery-token/attempt")
}
func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) {
r := c.c.NewRequest("GET", path)
@@ -34,6 +38,10 @@ func (c *Sys) GenerateDROperationTokenInit(otp, pgpKey string) (*GenerateRootSta
return c.generateRootInitCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt", otp, pgpKey)
}
func (c *Sys) GenerateRecoveryOperationTokenInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) {
return c.generateRootInitCommon("/v1/sys/generate-recovery-token/attempt", otp, pgpKey)
}
func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootStatusResponse, error) {
body := map[string]interface{}{
"otp": otp,
@@ -66,6 +74,10 @@ func (c *Sys) GenerateDROperationTokenCancel() error {
return c.generateRootCancelCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
}
func (c *Sys) GenerateRecoveryOperationTokenCancel() error {
return c.generateRootCancelCommon("/v1/sys/generate-recovery-token/attempt")
}
func (c *Sys) generateRootCancelCommon(path string) error {
r := c.c.NewRequest("DELETE", path)
@@ -86,6 +98,10 @@ func (c *Sys) GenerateDROperationTokenUpdate(shard, nonce string) (*GenerateRoot
return c.generateRootUpdateCommon("/v1/sys/replication/dr/secondary/generate-operation-token/update", shard, nonce)
}
func (c *Sys) GenerateRecoveryOperationTokenUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) {
return c.generateRootUpdateCommon("/v1/sys/generate-recovery-token/update", shard, nonce)
}
func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRootStatusResponse, error) {
body := map[string]interface{}{
"key": shard,


@@ -23,6 +23,14 @@ import (
var _ cli.Command = (*OperatorGenerateRootCommand)(nil)
var _ cli.CommandAutocomplete = (*OperatorGenerateRootCommand)(nil)
type generateRootKind int
const (
generateRootRegular generateRootKind = iota
generateRootDR
generateRootRecovery
)
type OperatorGenerateRootCommand struct {
*BaseCommand
@@ -35,6 +43,7 @@ type OperatorGenerateRootCommand struct {
flagNonce string
flagGenerateOTP bool
flagDRToken bool
flagRecoveryToken bool
testStdin io.Reader // for tests
}
@@ -143,6 +152,16 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets {
"tokens.",
})
f.BoolVar(&BoolVar{
Name: "recovery-token",
Target: &c.flagRecoveryToken,
Default: false,
EnvVar: "",
Completion: complete.PredictNothing,
Usage: "Set this flag to do generate root operations on Recovery Operational " +
"tokens.",
})
f.StringVar(&StringVar{
Name: "otp",
Target: &c.flagOTP,
@@ -200,43 +219,60 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int {
return 1
}
if c.flagDRToken && c.flagRecoveryToken {
c.UI.Error("Both -recovery-token and -dr-token flags are set")
return 1
}
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
kind := generateRootRegular
switch {
case c.flagDRToken:
kind = generateRootDR
case c.flagRecoveryToken:
kind = generateRootRecovery
}
switch {
case c.flagGenerateOTP:
otp, code := c.generateOTP(client, c.flagDRToken)
otp, code := c.generateOTP(client, kind)
if code == 0 {
return PrintRaw(c.UI, otp)
}
return code
case c.flagDecode != "":
return c.decode(client, c.flagDecode, c.flagOTP, c.flagDRToken)
return c.decode(client, c.flagDecode, c.flagOTP, kind)
case c.flagCancel:
return c.cancel(client, c.flagDRToken)
return c.cancel(client, kind)
case c.flagInit:
return c.init(client, c.flagOTP, c.flagPGPKey, c.flagDRToken)
return c.init(client, c.flagOTP, c.flagPGPKey, kind)
case c.flagStatus:
return c.status(client, c.flagDRToken)
return c.status(client, kind)
default:
// If there are no other flags, prompt for an unseal key.
key := ""
if len(args) > 0 {
key = strings.TrimSpace(args[0])
}
return c.provide(client, key, c.flagDRToken)
return c.provide(client, key, kind)
}
}
// generateOTP generates a suitable OTP code for generating a root token.
func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bool) (string, int) {
func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, kind generateRootKind) (string, int) {
f := client.Sys().GenerateRootStatus
if drToken {
switch kind {
case generateRootDR:
f = client.Sys().GenerateDROperationTokenStatus
case generateRootRecovery:
f = client.Sys().GenerateRecoveryOperationTokenStatus
}
status, err := f()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
@@ -272,7 +308,7 @@ func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bo
}
// decode decodes the given value using the otp.
func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, drToken bool) int {
func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, kind generateRootKind) int {
if encoded == "" {
c.UI.Error("Missing encoded value: use -decode=<string> to supply it")
return 1
@@ -283,9 +319,13 @@ func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp st
}
f := client.Sys().GenerateRootStatus
if drToken {
switch kind {
case generateRootDR:
f = client.Sys().GenerateDROperationTokenStatus
case generateRootRecovery:
f = client.Sys().GenerateRecoveryOperationTokenStatus
}
status, err := f()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
@@ -327,7 +367,7 @@ func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp st
}
// init is used to start the generation process
func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, drToken bool) int {
func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, kind generateRootKind) int {
// Validate incoming fields. Either OTP OR PGP keys must be supplied.
if otp != "" && pgpKey != "" {
c.UI.Error("Error initializing: cannot specify both -otp and -pgp-key")
@@ -336,8 +376,11 @@ func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey strin
// Start the root generation
f := client.Sys().GenerateRootInit
if drToken {
switch kind {
case generateRootDR:
f = client.Sys().GenerateDROperationTokenInit
case generateRootRecovery:
f = client.Sys().GenerateRecoveryOperationTokenInit
}
status, err := f(otp, pgpKey)
if err != nil {
@@ -355,10 +398,13 @@ func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey strin
// provide prompts the user for the seal key and posts it to the update root
// endpoint. If this is the last unseal, this function outputs it.
func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, drToken bool) int {
func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, kind generateRootKind) int {
f := client.Sys().GenerateRootStatus
if drToken {
switch kind {
case generateRootDR:
f = client.Sys().GenerateDROperationTokenStatus
case generateRootRecovery:
f = client.Sys().GenerateRecoveryOperationTokenStatus
}
status, err := f()
if err != nil {
@@ -437,8 +483,11 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr
// Provide the key, this may potentially complete the update
fUpd := client.Sys().GenerateRootUpdate
if drToken {
switch kind {
case generateRootDR:
fUpd = client.Sys().GenerateDROperationTokenUpdate
case generateRootRecovery:
fUpd = client.Sys().GenerateRecoveryOperationTokenUpdate
}
status, err = fUpd(key, nonce)
if err != nil {
@@ -454,10 +503,13 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr
}
// cancel cancels the root token generation
func (c *OperatorGenerateRootCommand) cancel(client *api.Client, drToken bool) int {
func (c *OperatorGenerateRootCommand) cancel(client *api.Client, kind generateRootKind) int {
f := client.Sys().GenerateRootCancel
if drToken {
switch kind {
case generateRootDR:
f = client.Sys().GenerateDROperationTokenCancel
case generateRootRecovery:
f = client.Sys().GenerateRecoveryOperationTokenCancel
}
if err := f(); err != nil {
c.UI.Error(fmt.Sprintf("Error canceling root token generation: %s", err))
@@ -468,11 +520,15 @@ func (c *OperatorGenerateRootCommand) cancel(client *api.Client, drToken bool) i
}
// status is used just to fetch and dump the status
func (c *OperatorGenerateRootCommand) status(client *api.Client, drToken bool) int {
func (c *OperatorGenerateRootCommand) status(client *api.Client, kind generateRootKind) int {
f := client.Sys().GenerateRootStatus
if drToken {
switch kind {
case generateRootDR:
f = client.Sys().GenerateDROperationTokenStatus
case generateRootRecovery:
f = client.Sys().GenerateRecoveryOperationTokenStatus
}
status, err := f()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))


@@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"go.uber.org/atomic"
"io"
"io/ioutil"
"net"
@@ -99,6 +100,7 @@ type ServerCommand struct {
flagConfigs []string
flagLogLevel string
flagLogFormat string
flagRecovery bool
flagDev bool
flagDevRootTokenID string
flagDevListenAddr string
@@ -197,6 +199,13 @@ func (c *ServerCommand) Flags() *FlagSets {
Usage: `Log format. Supported values are "standard" and "json".`,
})
f.BoolVar(&BoolVar{
Name: "recovery",
Target: &c.flagRecovery,
Usage: "Enable recovery mode. In this mode, Vault is used to perform recovery actions." +
"Using a recovery operation token, \"sys/raw\" API can be used to manipulate the storage.",
})
f = set.NewFlagSet("Dev Options")
f.BoolVar(&BoolVar{
@@ -365,6 +374,384 @@ func (c *ServerCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *ServerCommand) parseConfig() (*server.Config, error) {
// Load the configuration
var config *server.Config
for _, path := range c.flagConfigs {
current, err := server.LoadConfig(path)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error loading configuration from %s: {{err}}", path), err)
}
if config == nil {
config = current
} else {
config = config.Merge(current)
}
}
return config, nil
}
func (c *ServerCommand) runRecoveryMode() int {
config, err := c.parseConfig()
if err != nil {
c.UI.Error(err.Error())
return 1
}
// Ensure at least one config was found.
if config == nil {
c.UI.Output(wrapAtLength(
"No configuration files found. Please provide configurations with the " +
"-config flag. If you are supplying the path to a directory, please " +
"ensure the directory contains files with the .hcl or .json " +
"extension."))
return 1
}
level, logLevelString, logLevelWasNotSet, logFormat, err := c.processLogLevelAndFormat(config)
if err != nil {
c.UI.Error(err.Error())
return 1
}
c.logger = log.New(&log.LoggerOptions{
Output: c.logWriter,
Level: level,
// Note that if logFormat is either unspecified or standard, then
// the resulting logger's format will be standard.
JSONFormat: logFormat == logging.JSONFormat,
})
logLevelStr, err := c.adjustLogLevel(config, logLevelWasNotSet)
if err != nil {
c.UI.Error(err.Error())
return 1
}
if logLevelStr != "" {
logLevelString = logLevelStr
}
// create GRPC logger
namedGRPCLogFaker := c.logger.Named("grpclogfaker")
grpclog.SetLogger(&grpclogFaker{
logger: namedGRPCLogFaker,
log: os.Getenv("VAULT_GRPC_LOGGING") != "",
})
if config.Storage == nil {
c.UI.Output("A storage backend must be specified")
return 1
}
if config.DefaultMaxRequestDuration != 0 {
vault.DefaultMaxRequestDuration = config.DefaultMaxRequestDuration
}
proxyCfg := httpproxy.FromEnvironment()
c.logger.Info("proxy environment", "http_proxy", proxyCfg.HTTPProxy,
"https_proxy", proxyCfg.HTTPSProxy, "no_proxy", proxyCfg.NoProxy)
// Initialize the storage backend
factory, exists := c.PhysicalBackends[config.Storage.Type]
if !exists {
c.UI.Error(fmt.Sprintf("Unknown storage type %s", config.Storage.Type))
return 1
}
if config.Storage.Type == "raft" {
if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" {
config.ClusterAddr = envCA
}
if len(config.ClusterAddr) == 0 {
c.UI.Error("Cluster address must be set when using raft storage")
return 1
}
}
namedStorageLogger := c.logger.Named("storage." + config.Storage.Type)
backend, err := factory(config.Storage.Config, namedStorageLogger)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing storage of type %s: %s", config.Storage.Type, err))
return 1
}
infoKeys := make([]string, 0, 10)
info := make(map[string]string)
info["log level"] = logLevelString
infoKeys = append(infoKeys, "log level")
var barrierSeal vault.Seal
var sealConfigError error
if len(config.Seals) == 0 {
config.Seals = append(config.Seals, &server.Seal{Type: vaultseal.Shamir})
}
if len(config.Seals) > 1 {
c.UI.Error("Only one seal block is accepted in recovery mode")
return 1
}
configSeal := config.Seals[0]
sealType := vaultseal.Shamir
if !configSeal.Disabled && os.Getenv("VAULT_SEAL_TYPE") != "" {
sealType = os.Getenv("VAULT_SEAL_TYPE")
configSeal.Type = sealType
} else {
sealType = configSeal.Type
}
var seal vault.Seal
sealLogger := c.logger.Named(sealType)
seal, sealConfigError = serverseal.ConfigureSeal(configSeal, &infoKeys, &info, sealLogger, vault.NewDefaultSeal(shamirseal.NewSeal(c.logger.Named("shamir"))))
if sealConfigError != nil {
if !errwrap.ContainsType(sealConfigError, new(logical.KeyNotFoundError)) {
c.UI.Error(fmt.Sprintf(
"Error parsing Seal configuration: %s", sealConfigError))
return 1
}
}
if seal == nil {
c.UI.Error(fmt.Sprintf(
"After configuring seal nil returned, seal type was %s", sealType))
return 1
}
barrierSeal = seal
// Ensure that the seal finalizer is called, even if using verify-only
defer func() {
err = seal.Finalize(context.Background())
if err != nil {
c.UI.Error(fmt.Sprintf("Error finalizing seals: %v", err))
}
}()
coreConfig := &vault.CoreConfig{
Physical: backend,
StorageType: config.Storage.Type,
Seal: barrierSeal,
Logger: c.logger,
DisableMlock: config.DisableMlock,
RecoveryMode: c.flagRecovery,
ClusterAddr: config.ClusterAddr,
}
core, newCoreError := vault.NewCore(coreConfig)
if newCoreError != nil {
if vault.IsFatalError(newCoreError) {
c.UI.Error(fmt.Sprintf("Error initializing core: %s", newCoreError))
return 1
}
}
if err := core.InitializeRecovery(context.Background()); err != nil {
c.UI.Error(fmt.Sprintf("Error initializing core in recovery mode: %s", err))
return 1
}
// Compile server information for output later
infoKeys = append(infoKeys, "storage")
info["storage"] = config.Storage.Type
if coreConfig.ClusterAddr != "" {
info["cluster address"] = coreConfig.ClusterAddr
infoKeys = append(infoKeys, "cluster address")
}
// Initialize the listeners
lns := make([]ServerListener, 0, len(config.Listeners))
for _, lnConfig := range config.Listeners {
ln, _, _, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logWriter, c.UI)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err))
return 1
}
lns = append(lns, ServerListener{
Listener: ln,
config: lnConfig.Config,
})
}
listenerCloseFunc := func() {
for _, ln := range lns {
ln.Listener.Close()
}
}
defer c.cleanupGuard.Do(listenerCloseFunc)
infoKeys = append(infoKeys, "version")
verInfo := version.GetVersion()
info["version"] = verInfo.FullVersionNumber(false)
if verInfo.Revision != "" {
info["version sha"] = strings.Trim(verInfo.Revision, "'")
infoKeys = append(infoKeys, "version sha")
}
infoKeys = append(infoKeys, "recovery mode")
info["recovery mode"] = "true"
// Server configuration output
padding := 24
sort.Strings(infoKeys)
c.UI.Output("==> Vault server configuration:\n")
for _, k := range infoKeys {
c.UI.Output(fmt.Sprintf(
"%s%s: %s",
strings.Repeat(" ", padding-len(k)),
strings.Title(k),
info[k]))
}
c.UI.Output("")
for _, ln := range lns {
handler := vaulthttp.Handler(&vault.HandlerProperties{
Core: core,
MaxRequestSize: ln.maxRequestSize,
MaxRequestDuration: ln.maxRequestDuration,
DisablePrintableCheck: config.DisablePrintableCheck,
RecoveryMode: c.flagRecovery,
RecoveryToken: atomic.NewString(""),
})
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: c.logger.StandardLogger(nil),
}
go server.Serve(ln.Listener)
}
if sealConfigError != nil {
init, err := core.Initialized(context.Background())
if err != nil {
c.UI.Error(fmt.Sprintf("Error checking if core is initialized: %v", err))
return 1
}
if init {
c.UI.Error("Vault is initialized but no Seal key could be loaded")
return 1
}
}
if newCoreError != nil {
c.UI.Warn(wrapAtLength(
"WARNING! A non-fatal error occurred during initialization. Please " +
"check the logs for more information."))
c.UI.Warn("")
}
if !c.flagCombineLogs {
c.UI.Output("==> Vault server started! Log data will stream in below:\n")
}
c.logGate.Flush()
for {
select {
case <-c.ShutdownCh:
c.UI.Output("==> Vault shutdown triggered")
c.cleanupGuard.Do(listenerCloseFunc)
if err := core.Shutdown(); err != nil {
c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err))
}
return 0
case <-c.SigUSR2Ch:
buf := make([]byte, 32*1024*1024)
n := runtime.Stack(buf[:], true)
c.logger.Info("goroutine trace", "stack", string(buf[:n]))
}
}
return 0
}
func (c *ServerCommand) adjustLogLevel(config *server.Config, logLevelWasNotSet bool) (string, error) {
var logLevelString string
if config.LogLevel != "" && logLevelWasNotSet {
configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel))
logLevelString = configLogLevel
switch configLogLevel {
case "trace":
c.logger.SetLevel(log.Trace)
case "debug":
c.logger.SetLevel(log.Debug)
case "notice", "info", "":
c.logger.SetLevel(log.Info)
case "warn", "warning":
c.logger.SetLevel(log.Warn)
case "err", "error":
c.logger.SetLevel(log.Error)
default:
return "", fmt.Errorf("unknown log level: %s", config.LogLevel)
}
}
return logLevelString, nil
}
func (c *ServerCommand) processLogLevelAndFormat(config *server.Config) (log.Level, string, bool, logging.LogFormat, error) {
// Create a logger. We wrap it in a gated writer so that it doesn't
// start logging too early.
c.logGate = &gatedwriter.Writer{Writer: os.Stderr}
c.logWriter = c.logGate
if c.flagCombineLogs {
c.logWriter = os.Stdout
}
var level log.Level
var logLevelWasNotSet bool
logFormat := logging.UnspecifiedFormat
logLevelString := c.flagLogLevel
c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
switch c.flagLogLevel {
case notSetValue, "":
logLevelWasNotSet = true
logLevelString = "info"
level = log.Info
case "trace":
level = log.Trace
case "debug":
level = log.Debug
case "notice", "info":
level = log.Info
case "warn", "warning":
level = log.Warn
case "err", "error":
level = log.Error
default:
return level, logLevelString, logLevelWasNotSet, logFormat, fmt.Errorf("unknown log level: %s", c.flagLogLevel)
}
if c.flagLogFormat != notSetValue {
var err error
logFormat, err = logging.ParseLogFormat(c.flagLogFormat)
if err != nil {
return level, logLevelString, logLevelWasNotSet, logFormat, err
}
}
if logFormat == logging.UnspecifiedFormat {
logFormat = logging.ParseEnvLogFormat()
}
if logFormat == logging.UnspecifiedFormat {
var err error
logFormat, err = logging.ParseLogFormat(config.LogFormat)
if err != nil {
return level, logLevelString, logLevelWasNotSet, logFormat, err
}
}
return level, logLevelString, logLevelWasNotSet, logFormat, nil
}
func (c *ServerCommand) Run(args []string) int {
f := c.Flags()
@@ -373,6 +760,10 @@ func (c *ServerCommand) Run(args []string) int {
return 1
}
if c.flagRecovery {
return c.runRecoveryMode()
}
// Automatically enable dev mode if other dev flags are provided.
if c.flagDevHA || c.flagDevTransactional || c.flagDevLeasedKV || c.flagDevThreeNode || c.flagDevFourCluster || c.flagDevAutoSeal || c.flagDevKVV1 {
c.flagDev = true
@@ -413,18 +804,16 @@ func (c *ServerCommand) Run(args []string) int {
config.Listeners[0].Config["address"] = c.flagDevListenAddr
}
}
for _, path := range c.flagConfigs {
current, err := server.LoadConfig(path)
parsedConfig, err := c.parseConfig()
if err != nil {
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", path, err))
c.UI.Error(err.Error())
return 1
}
if config == nil {
config = current
config = parsedConfig
} else {
config = config.Merge(current)
}
config = config.Merge(parsedConfig)
}
// Ensure at least one config was found.
@@ -437,57 +826,11 @@ func (c *ServerCommand) Run(args []string) int {
return 1
}
// Create a logger. We wrap it in a gated writer so that it doesn't
// start logging too early.
c.logGate = &gatedwriter.Writer{Writer: os.Stderr}
c.logWriter = c.logGate
if c.flagCombineLogs {
c.logWriter = os.Stdout
}
var level log.Level
var logLevelWasNotSet bool
logLevelString := c.flagLogLevel
c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
switch c.flagLogLevel {
case notSetValue, "":
logLevelWasNotSet = true
logLevelString = "info"
level = log.Info
case "trace":
level = log.Trace
case "debug":
level = log.Debug
case "notice", "info":
level = log.Info
case "warn", "warning":
level = log.Warn
case "err", "error":
level = log.Error
default:
c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel))
return 1
}
logFormat := logging.UnspecifiedFormat
if c.flagLogFormat != notSetValue {
var err error
logFormat, err = logging.ParseLogFormat(c.flagLogFormat)
level, logLevelString, logLevelWasNotSet, logFormat, err := c.processLogLevelAndFormat(config)
if err != nil {
c.UI.Error(err.Error())
return 1
}
}
if logFormat == logging.UnspecifiedFormat {
logFormat = logging.ParseEnvLogFormat()
}
if logFormat == logging.UnspecifiedFormat {
var err error
logFormat, err = logging.ParseLogFormat(config.LogFormat)
if err != nil {
c.UI.Error(err.Error())
return 1
}
}
if c.flagDevThreeNode || c.flagDevFourCluster {
c.logger = log.New(&log.LoggerOptions{
@@ -507,25 +850,13 @@ func (c *ServerCommand) Run(args []string) int {
allLoggers := []log.Logger{c.logger}
// adjust log level based on config setting
if config.LogLevel != "" && logLevelWasNotSet {
configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel))
logLevelString = configLogLevel
switch configLogLevel {
case "trace":
c.logger.SetLevel(log.Trace)
case "debug":
c.logger.SetLevel(log.Debug)
case "notice", "info", "":
c.logger.SetLevel(log.Info)
case "warn", "warning":
c.logger.SetLevel(log.Warn)
case "err", "error":
c.logger.SetLevel(log.Error)
default:
c.UI.Error(fmt.Sprintf("Unknown log level: %s", config.LogLevel))
logLevelStr, err := c.adjustLogLevel(config, logLevelWasNotSet)
if err != nil {
c.UI.Error(err.Error())
return 1
}
if logLevelStr != "" {
logLevelString = logLevelStr
}
// create GRPC logger
@@ -580,7 +911,6 @@ func (c *ServerCommand) Run(args []string) int {
return 1
}
if config.Storage.Type == "raft" {
if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" {
config.ClusterAddr = envCA
@@ -1066,6 +1396,9 @@ CLUSTER_SYNTHESIS_COMPLETE:
info["cgo"] = "enabled"
}
infoKeys = append(infoKeys, "recovery mode")
info["recovery mode"] = "false"
// Server configuration output
padding := 24
sort.Strings(infoKeys)
@@ -1263,6 +1596,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
MaxRequestDuration: ln.maxRequestDuration,
DisablePrintableCheck: config.DisablePrintableCheck,
UnauthenticatedMetricsAccess: ln.unauthenticatedMetricsAccess,
RecoveryMode: c.flagRecovery,
})
// We perform validation on the config earlier, we can just cast here


@@ -19,16 +19,26 @@ import (
"github.com/mitchellh/go-testing-interface"
)
type GenerateRootKind int
const (
GenerateRootRegular GenerateRootKind = iota
GenerateRootDR
GenerateRecovery
)
// Generates a root token on the target cluster.
func GenerateRoot(t testing.T, cluster *vault.TestCluster, drToken bool) string {
token, err := GenerateRootWithError(t, cluster, drToken)
func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string {
t.Helper()
token, err := GenerateRootWithError(t, cluster, kind)
if err != nil {
t.Fatal(err)
}
return token
}
func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool) (string, error) {
func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) (string, error) {
t.Helper()
// If recovery keys supported, use those to perform root token generation instead
var keys [][]byte
if cluster.Cores[0].SealAccess().RecoveryKeySupported() {
@@ -36,13 +46,18 @@ func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool
} else {
keys = cluster.BarrierKeys
}
client := cluster.Cores[0].Client
f := client.Sys().GenerateRootInit
if drToken {
f = client.Sys().GenerateDROperationTokenInit
var err error
var status *api.GenerateRootStatusResponse
switch kind {
case GenerateRootRegular:
status, err = client.Sys().GenerateRootInit("", "")
case GenerateRootDR:
status, err = client.Sys().GenerateDROperationTokenInit("", "")
case GenerateRecovery:
status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
}
status, err := f("", "")
if err != nil {
return "", err
}
@@ -57,11 +72,16 @@ func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool
if i >= status.Required {
break
}
f := client.Sys().GenerateRootUpdate
if drToken {
f = client.Sys().GenerateDROperationTokenUpdate
strKey := base64.StdEncoding.EncodeToString(key)
switch kind {
case GenerateRootRegular:
status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
case GenerateRootDR:
status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
case GenerateRecovery:
status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
}
status, err = f(base64.StdEncoding.EncodeToString(key), status.Nonce)
if err != nil {
return "", err
}


@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/NYTimes/gziphandler"
"io"
"io/ioutil"
"net"
@@ -16,11 +17,10 @@ import (
"strings"
"time"
"github.com/NYTimes/gziphandler"
assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
@@ -111,9 +111,15 @@ func Handler(props *vault.HandlerProperties) http.Handler {
// Create the muxer to handle the actual endpoints
mux := http.NewServeMux()
// Handle non-forwarded paths
mux.Handle("/v1/sys/config/state/", handleLogicalNoForward(core))
mux.Handle("/v1/sys/host-info", handleLogicalNoForward(core))
switch {
case props.RecoveryMode:
raw := vault.NewRawBackend(core)
strategy := vault.GenerateRecoveryTokenStrategy(props.RecoveryToken)
mux.Handle("/v1/sys/raw/", handleLogicalRecovery(raw, props.RecoveryToken))
mux.Handle("/v1/sys/generate-recovery-token/attempt", handleSysGenerateRootAttempt(core, strategy))
mux.Handle("/v1/sys/generate-recovery-token/update", handleSysGenerateRootUpdate(core, strategy))
default:
// Handle pprof paths
mux.Handle("/v1/sys/pprof/", handleLogicalNoForward(core))
mux.Handle("/v1/sys/init", handleSysInit(core))
@@ -146,6 +152,7 @@ func Handler(props *vault.HandlerProperties) http.Handler {
}
mux.Handle("/ui", handleUIRedirect())
mux.Handle("/", handleUIRedirect())
}
// Register metrics path without authentication if enabled
@@ -154,6 +161,7 @@ func Handler(props *vault.HandlerProperties) http.Handler {
}
additionalRoutes(mux, core)
}
// Wrap the handler in another handler to trigger all help paths.
helpWrappedHandler := wrapHelpHandler(mux, core)
@@ -489,7 +497,7 @@ func parseQuery(values url.Values) map[string]interface{} {
return nil
}
func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
// Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
reader := r.Body
@@ -505,7 +513,7 @@ func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out
}
}
var origBody io.ReadWriter
if core.PerfStandby() {
if perfStandby {
// Since we're checking PerfStandby here we key on origBody being nil
// or not later, so we need to always allocate so it's non-nil
origBody = new(bytes.Buffer)
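
A protocol-level sketch of the recovery-mode endpoints registered above, assuming a recovery-mode server on localhost and a token already generated (the token value is a placeholder; LIST is the method Vault uses for list operations, and X-Vault-Token is the header handleLogicalRecovery checks):

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
)

func main() {
	// List raw storage under the barrier's sys/ prefix.
	req, err := http.NewRequest("LIST", "http://127.0.0.1:8200/v1/sys/raw/sys", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("X-Vault-Token", "r.PLACEHOLDER") // recovery token
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := ioutil.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}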


@@ -4,6 +4,8 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/hashicorp/vault/sdk/helper/consts"
"go.uber.org/atomic"
"io"
"net"
"net/http"
@@ -18,7 +20,7 @@ import (
"github.com/hashicorp/vault/vault"
)
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
ns, err := namespace.FromContext(r.Context())
if err != nil {
return nil, nil, http.StatusBadRequest, nil
@@ -78,7 +80,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
passHTTPReq = true
origBody = r.Body
} else {
origBody, err = parseRequest(core, r, w, &data)
origBody, err = parseRequest(perfStandby, r, w, &data)
if err == io.EOF {
data = nil
err = nil
@@ -105,14 +107,32 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
}
req, err := requestAuth(core, r, &logical.Request{
req := &logical.Request{
ID: request_id,
Operation: op,
Path: path,
Data: data,
Connection: getConnection(r),
Headers: r.Header,
})
}
if passHTTPReq {
req.HTTPRequest = r
}
if responseWriter != nil {
req.ResponseWriter = logical.NewHTTPResponseWriter(responseWriter)
}
return req, origBody, 0, nil
}
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
if err != nil {
return nil, nil, status, err
}
req, err = requestAuth(core, r, req)
if err != nil {
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
return nil, nil, http.StatusForbidden, nil
@@ -135,12 +155,6 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
}
if passHTTPReq {
req.HTTPRequest = r
}
if responseWriter != nil {
req.ResponseWriter = logical.NewHTTPResponseWriter(responseWriter)
}
return req, origBody, 0, nil
}
@@ -168,6 +182,32 @@ func handleLogicalNoForward(core *vault.Core) http.Handler {
return handleLogicalInternal(core, false, true)
}
func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, _, statusCode, err := buildLogicalRequestNoAuth(false, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
}
reqToken := r.Header.Get(consts.AuthHeaderName)
if reqToken == "" || token.Load() == "" || reqToken != token.Load() {
respondError(w, http.StatusForbidden, nil)
return
}
resp, err := raw.HandleRequest(r.Context(), req)
if respondErrorCommon(w, req, resp, err) {
return
}
var httpResp *logical.HTTPResponse
if resp != nil {
httpResp = logical.LogicalResponseToHTTPResponse(resp)
httpResp.RequestID = req.ID
}
respondOk(w, httpResp)
})
}
// handleLogicalInternal is a common helper that returns a handler for
// processing logical requests. The behavior depends on the various boolean
// toggles. Refer to usage on functions for possible behaviors.


@@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r
func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) {
// Parse the request
var req GenerateRootInitRequest
if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
respondError(w, http.StatusBadRequest, err)
return
}
@@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Parse the request
var req GenerateRootUpdateRequest
if _, err := parseRequest(core, r, w, &req); err != nil {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}


@@ -40,7 +40,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request)
// Parse the request
var req InitRequest
if _, err := parseRequest(core, r, w, &req); err != nil {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}


@@ -25,7 +25,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler {
func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req JoinRequest
if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
respondError(w, http.StatusBadRequest, err)
return
}


@@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool,
func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req RekeyRequest
if _, err := parseRequest(core, r, w, &req); err != nil {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
@@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler {
// Parse the request
var req RekeyUpdateRequest
if _, err := parseRequest(core, r, w, &req); err != nil {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
@@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery
func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req RekeyVerificationUpdateRequest
if _, err := parseRequest(core, r, w, &req); err != nil {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}


@@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
// Parse the request
var req UnsealRequest
if _, err := parseRequest(core, r, w, &req); err != nil {
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}


@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/armon/go-metrics"
"io"
"io/ioutil"
"os"
@@ -12,12 +13,11 @@ import (
"sync"
"time"
metrics "github.com/armon/go-metrics"
proto "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-raftchunking"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/raft"
snapshot "github.com/hashicorp/raft-snapshot"
raftboltdb "github.com/hashicorp/vault/physical/raft/logstore"
@@ -367,6 +367,26 @@ type SetupOpts struct {
// StartAsLeader is used to specify this node should start as leader and
// bypass the leader election. This should be used with caution.
StartAsLeader bool
// RecoveryModeConfig is the configuration for the raft cluster in recovery
// mode.
RecoveryModeConfig *raft.Configuration
}
func (b *RaftBackend) StartRecoveryCluster(ctx context.Context, peer Peer) error {
recoveryModeConfig := &raft.Configuration{
Servers: []raft.Server{
{
ID: raft.ServerID(peer.ID),
Address: raft.ServerAddress(peer.Address),
},
},
}
return b.SetupCluster(context.Background(), SetupOpts{
StartAsLeader: true,
RecoveryModeConfig: recoveryModeConfig,
})
}
// SetupCluster starts the raft cluster and enables the networking needed for
@@ -477,6 +497,13 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error {
b.logger.Info("raft recovery deleted peers.json")
}
if opts.RecoveryModeConfig != nil {
err = raft.RecoverCluster(raftConfig, b.fsm, b.logStore, b.stableStore, b.snapStore, b.raftTransport, *opts.RecoveryModeConfig)
if err != nil {
return errwrap.Wrapf("recovering raft cluster failed: {{err}}", err)
}
}
raftObj, err := raft.NewRaft(raftConfig, b.fsm.chunker, b.logStore, b.stableStore, b.snapStore, b.raftTransport)
b.fsm.SetNoopRestore(false)
if err != nil {


@@ -423,6 +423,10 @@ type Core struct {
// Stores any funcs that should be run on successful postUnseal
postUnsealFuncs []func()
// Stores any funcs that should be run on successful barrier unseal in
// recovery mode
postRecoveryUnsealFuncs []func() error
// replicationFailure is used to mark when replication has entered an
// unrecoverable failure.
replicationFailure *uint32
@@ -465,6 +469,8 @@ type Core struct {
rawConfig *server.Config
coreNumber int
recoveryMode bool
}
// CoreConfig is used to parameterize a core
@@ -542,6 +548,8 @@ type CoreConfig struct {
MetricsHelper *metricsutil.MetricsHelper
CounterSyncInterval time.Duration
RecoveryMode bool
}
func (c *CoreConfig) Clone() *CoreConfig {
@@ -668,6 +676,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
requests: new(uint64),
syncInterval: syncInterval,
},
recoveryMode: conf.RecoveryMode,
}
atomic.StoreUint32(c.sealed, 1)
@@ -726,25 +735,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
var err error
if conf.PluginDirectory != "" {
c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory)
if err != nil {
return nil, errwrap.Wrapf("core setup failed, could not verify plugin directory: {{err}}", err)
}
}
// Construct a new AES-GCM barrier
c.barrier, err = NewAESGCMBarrier(c.physical)
if err != nil {
return nil, errwrap.Wrapf("barrier setup failed: {{err}}", err)
}
createSecondaries(c, conf)
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
c.ha = conf.HAPhysical
}
// We create the funcs here, then populate the given config with it so that
// the caller can share state
conf.ReloadFuncsLock = &c.reloadFuncsLock
@@ -753,6 +749,25 @@ func NewCore(conf *CoreConfig) (*Core, error) {
c.reloadFuncsLock.Unlock()
conf.ReloadFuncs = &c.reloadFuncs
// All the things happening below this are not required in
// recovery mode
if c.recoveryMode {
return c, nil
}
if conf.PluginDirectory != "" {
c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory)
if err != nil {
return nil, errwrap.Wrapf("core setup failed, could not verify plugin directory: {{err}}", err)
}
}
createSecondaries(c, conf)
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
c.ha = conf.HAPhysical
}
logicalBackends := make(map[string]logical.Factory)
for k, f := range conf.LogicalBackends {
logicalBackends[k] = f


@@ -218,7 +218,7 @@ func testSysRekey_Verification(t *testing.T, recovery bool) {
} else {
// We haven't finished, so generating a root token should still be the
// old keys (which are still currently set)
testhelpers.GenerateRoot(t, cluster, false)
testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular)
}
// Provide the final new key
@@ -256,7 +256,7 @@ func testSysRekey_Verification(t *testing.T, recovery bool) {
}
} else {
// The old keys should no longer work
_, err := testhelpers.GenerateRootWithError(t, cluster, false)
_, err := testhelpers.GenerateRootWithError(t, cluster, testhelpers.GenerateRootRegular)
if err == nil {
t.Fatal("expected error")
}
@@ -273,6 +273,6 @@ func testSysRekey_Verification(t *testing.T, recovery bool) {
if err := client.Sys().GenerateRootCancel(); err != nil {
t.Fatal(err)
}
testhelpers.GenerateRoot(t, cluster, false)
testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular)
}
}


@@ -1,4 +1,4 @@
package token
package misc
import (
"bytes"


@@ -1,4 +1,4 @@
package token
package misc
import (
"testing"


@@ -0,0 +1,140 @@
package misc
import (
"github.com/go-test/deep"
"go.uber.org/atomic"
"path"
"testing"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/hashicorp/vault/vault"
)
func TestRecovery(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Debug).Named(t.Name())
inm, err := inmem.NewInmemHA(nil, logger)
if err != nil {
t.Fatal(err)
}
var keys [][]byte
var secretUUID string
var rootToken string
{
conf := vault.CoreConfig{
Physical: inm,
Logger: logger,
}
opts := vault.TestClusterOptions{
HandlerFunc: http.Handler,
NumCores: 1,
}
cluster := vault.NewTestCluster(t, &conf, &opts)
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
rootToken = client.Token()
var fooVal = map[string]interface{}{"bar": 1.0}
_, err = client.Logical().Write("secret/foo", fooVal)
if err != nil {
t.Fatal(err)
}
secret, err := client.Logical().List("secret/")
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(secret.Data["keys"], []interface{}{"foo"}); len(diff) > 0 {
t.Fatalf("got=%v, want=%v, diff: %v", secret.Data["keys"], []string{"foo"}, diff)
}
mounts, err := cluster.Cores[0].Client.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}
secretMount := mounts["secret/"]
if secretMount == nil {
t.Fatalf("secret mount not found, mounts: %v", mounts)
}
secretUUID = secretMount.UUID
cluster.EnsureCoresSealed(t)
keys = cluster.BarrierKeys
}
{
// Now bring it up in recovery mode.
var tokenRef atomic.String
conf := vault.CoreConfig{
Physical: inm,
Logger: logger,
RecoveryMode: true,
}
opts := vault.TestClusterOptions{
HandlerFunc: http.Handler,
NumCores: 1,
SkipInit: true,
DefaultHandlerProperties: vault.HandlerProperties{
RecoveryMode: true,
RecoveryToken: &tokenRef,
},
}
cluster := vault.NewTestCluster(t, &conf, &opts)
cluster.BarrierKeys = keys
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
recoveryToken := testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRecovery)
_, err = testhelpers.GenerateRootWithError(t, cluster, testhelpers.GenerateRecovery)
if err == nil {
t.Fatal("expected second generate-root to fail")
}
client.SetToken(recoveryToken)
secret, err := client.Logical().List(path.Join("sys/raw/logical", secretUUID))
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(secret.Data["keys"], []interface{}{"foo"}); len(diff) > 0 {
t.Fatalf("got=%v, want=%v, diff: %v", secret.Data, []string{"foo"}, diff)
}
_, err = client.Logical().Delete(path.Join("sys/raw/logical", secretUUID, "foo"))
if err != nil {
t.Fatal(err)
}
cluster.EnsureCoresSealed(t)
}
{
// Now go back to regular mode and verify that our changes are present
conf := vault.CoreConfig{
Physical: inm,
Logger: logger,
}
opts := vault.TestClusterOptions{
HandlerFunc: http.Handler,
NumCores: 1,
SkipInit: true,
}
cluster := vault.NewTestCluster(t, &conf, &opts)
cluster.BarrierKeys = keys
cluster.Start()
testhelpers.EnsureCoresUnsealed(t, cluster)
defer cluster.Cleanup()
client := cluster.Cores[0].Client
client.SetToken(rootToken)
secret, err := client.Logical().List("secret/")
if err != nil {
t.Fatal(err)
}
if secret != nil {
t.Fatal("expected no data in secret mount")
}
}
}


@@ -4,10 +4,10 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/helper/xor"
"github.com/hashicorp/vault/sdk/helper/consts"
@@ -77,10 +77,10 @@ type GenerateRootResult struct {
func (c *Core) GenerateRootProgress() (int, error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.Sealed() {
if c.Sealed() && !c.recoveryMode {
return 0, consts.ErrSealed
}
if c.standby {
if c.standby && !c.recoveryMode {
return 0, consts.ErrStandby
}
@@ -95,10 +95,10 @@ func (c *Core) GenerateRootProgress() (int, error) {
func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.Sealed() {
if c.Sealed() && !c.recoveryMode {
return nil, consts.ErrSealed
}
if c.standby {
if c.standby && !c.recoveryMode {
return nil, consts.ErrStandby
}
@@ -141,10 +141,17 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.Sealed() {
if c.Sealed() && !c.recoveryMode {
return consts.ErrSealed
}
if c.standby {
barrierSealed, err := c.barrier.Sealed()
if err != nil {
return errors.New("unable to check barrier seal status")
}
if !barrierSealed && c.recoveryMode {
return errors.New("attempt to generate recovery operation token when already unsealed")
}
if c.standby && !c.recoveryMode {
return consts.ErrStandby
}
@@ -174,6 +181,8 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg
switch strategy.(type) {
case generateStandardRootToken:
c.logger.Info("root generation initialized", "nonce", c.generateRootConfig.Nonce)
case *generateRecoveryToken:
c.logger.Info("recovery operation token generation initialized", "nonce", c.generateRootConfig.Nonce)
default:
c.logger.Info("dr operation token generation initialized", "nonce", c.generateRootConfig.Nonce)
}
@@ -217,10 +226,19 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
// Ensure we are already unsealed
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.Sealed() {
if c.Sealed() && !c.recoveryMode {
return nil, consts.ErrSealed
}
if c.standby {
barrierSealed, err := c.barrier.Sealed()
if err != nil {
return nil, errors.New("unable to check barrier seal status")
}
if !barrierSealed && c.recoveryMode {
return nil, errors.New("attempt to generate recovery operation token when already unsealed")
}
if c.standby && !c.recoveryMode {
return nil, consts.ErrStandby
}
@@ -263,31 +281,82 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
}, nil
}
// Recover the master key
var masterKey []byte
// Combine the key parts
var combinedKey []byte
if config.SecretThreshold == 1 {
masterKey = c.generateRootProgress[0]
combinedKey = c.generateRootProgress[0]
c.generateRootProgress = nil
} else {
masterKey, err = shamir.Combine(c.generateRootProgress)
combinedKey, err = shamir.Combine(c.generateRootProgress)
c.generateRootProgress = nil
if err != nil {
return nil, errwrap.Wrapf("failed to compute master key: {{err}}", err)
}
}
// Verify the master key
if c.seal.RecoveryKeySupported() {
if err := c.seal.VerifyRecoveryKey(ctx, masterKey); err != nil {
switch {
case c.seal.RecoveryKeySupported():
// Ensure that the combined recovery key is valid
if err := c.seal.VerifyRecoveryKey(ctx, combinedKey); err != nil {
c.logger.Error("root generation aborted, recovery key verification failed", "error", err)
return nil, err
}
} else {
if err := c.barrier.VerifyMaster(masterKey); err != nil {
// If we are in recovery mode, then retrieve
// the stored keys and unseal the barrier
if c.recoveryMode {
if !c.seal.StoredKeysSupported() {
c.logger.Error("root generation aborted, recovery key verified but stored keys unsupported")
return nil, errors.New("recovery key verified but stored keys unsupported")
}
masterKeyShares, err := c.seal.GetStoredKeys(ctx)
if err != nil {
return nil, errwrap.Wrapf("unable to retrieve stored keys in recovery mode: {{err}}", err)
}
switch len(masterKeyShares) {
case 0:
return nil, errors.New("seal returned no master key shares in recovery mode")
case 1:
combinedKey = masterKeyShares[0]
default:
combinedKey, err = shamir.Combine(masterKeyShares)
if err != nil {
return nil, errwrap.Wrapf("failed to compute master key in recovery mode: {{err}}", err)
}
}
// Use the retrieved master key to unseal the barrier
if err := c.barrier.Unseal(ctx, combinedKey); err != nil {
c.logger.Error("root generation aborted, recovery operation token verification failed", "error", err)
return nil, err
}
}
default:
switch {
case c.recoveryMode:
// If we are in recovery mode, being able to unseal
// the barrier is how we establish authentication
if err := c.barrier.Unseal(ctx, combinedKey); err != nil {
c.logger.Error("root generation aborted, recovery operation token verification failed", "error", err)
return nil, err
}
default:
if err := c.barrier.VerifyMaster(combinedKey); err != nil {
c.logger.Error("root generation aborted, master key verification failed", "error", err)
return nil, err
}
}
}
// Authentication in recovery mode is successful
if c.recoveryMode {
// Run any post unseal functions that are set
for _, v := range c.postRecoveryUnsealFuncs {
if err := v(); err != nil {
return nil, errwrap.Wrapf("failed to run post unseal func: {{err}}", err)
}
}
}
// Run the generate strategy
token, cleanupFunc, err := strategy.generate(ctx, c)
@@ -334,14 +403,12 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string,
switch strategy.(type) {
case generateStandardRootToken:
if c.logger.IsInfo() {
c.logger.Info("root generation finished", "nonce", c.generateRootConfig.Nonce)
}
case *generateRecoveryToken:
c.logger.Info("recovery operation token generation finished", "nonce", c.generateRootConfig.Nonce)
default:
if c.logger.IsInfo() {
c.logger.Info("dr operation token generation finished", "nonce", c.generateRootConfig.Nonce)
}
}
c.generateRootProgress = nil
c.generateRootConfig = nil
@@ -352,10 +419,10 @@ func (c *Core) GenerateRootCancel() error {
func (c *Core) GenerateRootCancel() error {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.Sealed() {
if c.Sealed() && !c.recoveryMode {
return consts.ErrSealed
}
if c.standby {
if c.standby && !c.recoveryMode {
return consts.ErrStandby
}


@@ -0,0 +1,31 @@
package vault
import (
"context"
"github.com/hashicorp/vault/sdk/helper/base62"
"go.uber.org/atomic"
)
// GenerateRecoveryTokenStrategy is the strategy used to generate a
// recovery token
func GenerateRecoveryTokenStrategy(token *atomic.String) GenerateRootStrategy {
return &generateRecoveryToken{token: token}
}
// generateRecoveryToken implements the GenerateRootStrategy and is in
// charge of creating recovery tokens.
type generateRecoveryToken struct {
token *atomic.String
}
func (g *generateRecoveryToken) generate(ctx context.Context, c *Core) (string, func(), error) {
id, err := base62.Random(TokenLength)
if err != nil {
return "", nil, err
}
token := "r." + id
g.token.Store(token)
return token, func() { g.token.Store("") }, nil
}


@@ -38,6 +38,31 @@ var (
initInProgress uint32
)
func (c *Core) InitializeRecovery(ctx context.Context) error {
if !c.recoveryMode {
return nil
}
raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend)
if !ok {
return nil
}
parsedClusterAddr, err := url.Parse(c.ClusterAddr())
if err != nil {
return err
}
c.postRecoveryUnsealFuncs = append(c.postRecoveryUnsealFuncs, func() error {
return raftStorage.StartRecoveryCluster(context.Background(), raft.Peer{
ID: raftStorage.NodeID(),
Address: parsedClusterAddr.Host,
})
})
return nil
}
// Initialized checks if the Vault is already initialized
func (c *Core) Initialized(ctx context.Context) (bool, error) {
// Check the barrier first

vault/logical_raw.go (new file, 216 lines)

@@ -0,0 +1,216 @@
package vault
import (
"context"
"fmt"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/compressutil"
"github.com/hashicorp/vault/sdk/logical"
"strings"
)
var (
// protectedPaths cannot be accessed via the raw APIs.
// This is both for security and to prevent disrupting Vault.
protectedPaths = []string{
keyringPath,
// Changing the cluster info path can change the cluster ID which can be disruptive
coreLocalClusterInfoPath,
}
)
type RawBackend struct {
*framework.Backend
barrier SecurityBarrier
logger log.Logger
checkRaw func(path string) error
recoveryMode bool
}
func NewRawBackend(core *Core) *RawBackend {
r := &RawBackend{
barrier: core.barrier,
logger: core.logger.Named("raw"),
checkRaw: func(path string) error {
return nil
},
recoveryMode: core.recoveryMode,
}
r.Backend = &framework.Backend{
Paths: rawPaths("sys/", r),
}
return r
}
// handleRawRead is used to read directly from the barrier
func (b *RawBackend) handleRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
if b.recoveryMode {
b.logger.Info("reading", "path", path)
}
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot read '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
// Run additional checks if needed
if err := b.checkRaw(path); err != nil {
b.logger.Warn(err.Error(), "path", path)
return logical.ErrorResponse("cannot read '%s'", path), logical.ErrInvalidRequest
}
entry, err := b.barrier.Get(ctx, path)
if err != nil {
return handleErrorNoReadOnlyForward(err)
}
if entry == nil {
return nil, nil
}
// Run this through the decompression helper to see if it's been compressed.
// If the input contained the compression canary, `outputBytes` will hold
// the decompressed data. If the input was not compressed, then `outputBytes`
// will be nil.
outputBytes, _, err := compressutil.Decompress(entry.Value)
if err != nil {
return handleErrorNoReadOnlyForward(err)
}
// `outputBytes` is nil if the input is uncompressed. In that case set it to the original input.
if outputBytes == nil {
outputBytes = entry.Value
}
resp := &logical.Response{
Data: map[string]interface{}{
"value": string(outputBytes),
},
}
return resp, nil
}
// handleRawWrite is used to write directly to the barrier
func (b *RawBackend) handleRawWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
if b.recoveryMode {
b.logger.Info("writing", "path", path)
}
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot write '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
value := data.Get("value").(string)
entry := &logical.StorageEntry{
Key: path,
Value: []byte(value),
}
if err := b.barrier.Put(ctx, entry); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
return nil, nil
}
// handleRawDelete is used to delete directly from the barrier
func (b *RawBackend) handleRawDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
if b.recoveryMode {
b.logger.Info("deleting", "path", path)
}
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot delete '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
if err := b.barrier.Delete(ctx, path); err != nil {
return handleErrorNoReadOnlyForward(err)
}
return nil, nil
}
// handleRawList is used to list directly from the barrier
func (b *RawBackend) handleRawList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
if path != "" && !strings.HasSuffix(path, "/") {
path = path + "/"
}
if b.recoveryMode {
b.logger.Info("listing", "path", path)
}
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot list '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
// Run additional checks if needed
if err := b.checkRaw(path); err != nil {
b.logger.Warn(err.Error(), "path", path)
return logical.ErrorResponse("cannot list '%s'", path), logical.ErrInvalidRequest
}
keys, err := b.barrier.List(ctx, path)
if err != nil {
return handleErrorNoReadOnlyForward(err)
}
return logical.ListResponse(keys), nil
}
func rawPaths(prefix string, r *RawBackend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: prefix + "(raw/?$|raw/(?P<path>.+))",
Fields: map[string]*framework.FieldSchema{
"path": &framework.FieldSchema{
Type: framework.TypeString,
},
"value": &framework.FieldSchema{
Type: framework.TypeString,
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Callback: r.handleRawRead,
Summary: "Read the value of the key at the given path.",
},
logical.UpdateOperation: &framework.PathOperation{
Callback: r.handleRawWrite,
Summary: "Update the value of the key at the given path.",
},
logical.DeleteOperation: &framework.PathOperation{
Callback: r.handleRawDelete,
Summary: "Delete the key with given path.",
},
logical.ListOperation: &framework.PathOperation{
Callback: r.handleRawList,
Summary: "Return a list keys for a given path prefix.",
},
},
HelpSynopsis: strings.TrimSpace(sysHelp["raw"][0]),
HelpDescription: strings.TrimSpace(sysHelp["raw"][1]),
},
}
}
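
The route pattern deserves a closer look: with the "sys/" prefix it matches a bare sys/raw (list at the root) as well as sys/raw/<anything>, capturing the remainder in the path field. A standalone check, under the assumption that the framework anchors path patterns with ^ and $:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`^sys/(raw/?$|raw/(?P<path>.+))$`)

	for _, p := range []string{"sys/raw", "sys/raw/", "sys/raw/secret/foo", "sys/rawdata"} {
		m := re.FindStringSubmatch(p)
		if m == nil {
			fmt.Printf("%-18s -> no match\n", p)
			continue
		}
		// m[2] is the named "path" capture group.
		fmt.Printf("%-18s -> path=%q\n", p, m[2])
	}
}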

vault/logical_system.go

@@ -23,14 +23,13 @@ import (
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
memdb "github.com/hashicorp/go-memdb"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/hostutil"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/compressutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
@@ -40,16 +39,6 @@ import (
"github.com/mitchellh/mapstructure"
)
var (
// protectedPaths cannot be accessed via the raw APIs.
// This is both for security and to prevent disrupting Vault.
protectedPaths = []string{
keyringPath,
// Changing the cluster info path can change the cluster ID which can be disruptive
coreLocalClusterInfoPath,
}
)
const maxBytes = 128 * 1024
func systemBackendMemDBSchema() *memdb.DBSchema {
@@ -172,40 +161,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend {
b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath())
if core.rawEnabled {
b.Backend.Paths = append(b.Backend.Paths, &framework.Path{
Pattern: "(raw/?$|raw/(?P<path>.+))",
Fields: map[string]*framework.FieldSchema{
"path": &framework.FieldSchema{
Type: framework.TypeString,
},
"value": &framework.FieldSchema{
Type: framework.TypeString,
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Callback: b.handleRawRead,
Summary: "Read the value of the key at the given path.",
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.handleRawWrite,
Summary: "Update the value of the key at the given path.",
},
logical.DeleteOperation: &framework.PathOperation{
Callback: b.handleRawDelete,
Summary: "Delete the key with given path.",
},
logical.ListOperation: &framework.PathOperation{
Callback: b.handleRawList,
Summary: "Return a list keys for a given path prefix.",
},
},
HelpSynopsis: strings.TrimSpace(sysHelp["raw"][0]),
HelpDescription: strings.TrimSpace(sysHelp["raw"][1]),
})
b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...)
}
if _, ok := core.underlyingPhysical.(*raft.RaftBackend); ok {
@@ -216,6 +172,17 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend {
return b
}
func (b *SystemBackend) rawPaths() []*framework.Path {
r := &RawBackend{
barrier: b.Core.barrier,
logger: b.logger,
checkRaw: func(path string) error {
return checkRaw(b, path)
},
}
return rawPaths("", r)
}
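
The empty prefix here is deliberate: SystemBackend is itself mounted at sys/, so its raw routes register relative to that mount, whereas recovery mode serves the standalone RawBackend directly and bakes the sys/ prefix into the pattern (see NewRawBackend above). The inferred net effect:

// Inferred routing (illustrative comment, not in the diff):
//   regular mode:  /v1/sys/raw/<path> -> SystemBackend, rawPaths("", r)
//   recovery mode: /v1/sys/raw/<path> -> RawBackend,    rawPaths("sys/", r)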
// SystemBackend implements logical.Backend and is used to interact with
// the core of the system. This backend is hardcoded to exist at the "sys"
// prefix. Conceptually it is similar to procfs on Linux.
@@ -2248,123 +2215,6 @@ func (b *SystemBackend) handleConfigUIHeadersDelete(ctx context.Context, req *lo
return nil, nil
}
// handleRawRead is used to read directly from the barrier
func (b *SystemBackend) handleRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot read '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
// Run additional checks if needed
if err := checkRaw(b, path); err != nil {
b.Core.logger.Warn(err.Error(), "path", path)
return logical.ErrorResponse("cannot read '%s'", path), logical.ErrInvalidRequest
}
entry, err := b.Core.barrier.Get(ctx, path)
if err != nil {
return handleErrorNoReadOnlyForward(err)
}
if entry == nil {
return nil, nil
}
// Run this through the decompression helper to see if it's been compressed.
// If the input contained the compression canary, `outputBytes` will hold
// the decompressed data. If the input was not compressed, then `outputBytes`
// will be nil.
outputBytes, _, err := compressutil.Decompress(entry.Value)
if err != nil {
return handleErrorNoReadOnlyForward(err)
}
// `outputBytes` is nil if the input is uncompressed. In that case set it to the original input.
if outputBytes == nil {
outputBytes = entry.Value
}
resp := &logical.Response{
Data: map[string]interface{}{
"value": string(outputBytes),
},
}
return resp, nil
}
// handleRawWrite is used to write directly to the barrier
func (b *SystemBackend) handleRawWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot write '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
value := data.Get("value").(string)
entry := &logical.StorageEntry{
Key: path,
Value: []byte(value),
}
if err := b.Core.barrier.Put(ctx, entry); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
return nil, nil
}
// handleRawDelete is used to delete directly from the barrier
func (b *SystemBackend) handleRawDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot delete '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
if err := b.Core.barrier.Delete(ctx, path); err != nil {
return handleErrorNoReadOnlyForward(err)
}
return nil, nil
}
// handleRawList is used to list directly from the barrier
func (b *SystemBackend) handleRawList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
if path != "" && !strings.HasSuffix(path, "/") {
path = path + "/"
}
// Prevent access of protected paths
for _, p := range protectedPaths {
if strings.HasPrefix(path, p) {
err := fmt.Sprintf("cannot list '%s'", path)
return logical.ErrorResponse(err), logical.ErrInvalidRequest
}
}
// Run additional checks if needed
if err := checkRaw(b, path); err != nil {
b.Core.logger.Warn(err.Error(), "path", path)
return logical.ErrorResponse("cannot list '%s'", path), logical.ErrInvalidRequest
}
keys, err := b.Core.barrier.List(ctx, path)
if err != nil {
return handleErrorNoReadOnlyForward(err)
}
return logical.ListResponse(keys), nil
}
// handleKeyStatus returns status information about the backend key
func (b *SystemBackend) handleKeyStatus(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Get the key info


@@ -22,6 +22,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/helper/wrapping"
"github.com/hashicorp/vault/sdk/logical"
uberAtomic "go.uber.org/atomic"
)
const (
@@ -43,6 +44,8 @@ type HandlerProperties struct {
MaxRequestSize int64
MaxRequestDuration time.Duration
DisablePrintableCheck bool
RecoveryMode bool
RecoveryToken *uberAtomic.String
UnauthenticatedMetricsAccess bool
}
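
A hedged construction sketch: the atomic string gives the HTTP layer a lock-free cell it can read on every request, and the core a place to publish the token once generate-recovery-token completes. The function and package names below are illustrative, not Vault's own wiring:

package demo

import (
	vault "github.com/hashicorp/vault/vault"
	uberAtomic "go.uber.org/atomic"
)

// newRecoveryHandlerProps is hypothetical wiring, not part of the diff.
func newRecoveryHandlerProps(core *vault.Core) *vault.HandlerProperties {
	return &vault.HandlerProperties{
		Core:          core,
		RecoveryMode:  true,
		RecoveryToken: uberAtomic.NewString(""), // filled in later via Store
	}
}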

vault/testing.go

@@ -1042,6 +1042,7 @@ type TestClusterOptions struct {
KeepStandbysSealed bool
SkipInit bool
HandlerFunc func(*HandlerProperties) http.Handler
DefaultHandlerProperties HandlerProperties
BaseListenAddress string
NumCores int
SealFunc func() Seal
@@ -1417,7 +1418,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
coreConfig.DevToken = base.DevToken
coreConfig.CounterSyncInterval = base.CounterSyncInterval
coreConfig.RecoveryMode = base.RecoveryMode
}
if coreConfig.RawConfig == nil {
@@ -1511,10 +1512,12 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
cores = append(cores, c)
coreConfigs = append(coreConfigs, &localConfig)
if opts != nil && opts.HandlerFunc != nil {
handlers[i] = opts.HandlerFunc(&HandlerProperties{
Core: c,
MaxRequestDuration: DefaultMaxRequestDuration,
})
props := opts.DefaultHandlerProperties
props.Core = c
if props.MaxRequestDuration == 0 {
props.MaxRequestDuration = DefaultMaxRequestDuration
}
handlers[i] = opts.HandlerFunc(&props)
servers[i].Handler = handlers[i]
}
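
Taken together, a recovery-mode test cluster can now be configured through the template. A sketch under the assumptions that vaulthttp is github.com/hashicorp/vault/http and that its Handler satisfies the HandlerFunc signature, as Vault's tests commonly use it:

func TestRecoveryModeCluster(t *testing.T) {
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{RecoveryMode: true}, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
		DefaultHandlerProperties: vault.HandlerProperties{
			RecoveryMode:  true,
			RecoveryToken: uberAtomic.NewString(""),
		},
		NumCores: 1,
	})
	cluster.Start()
	defer cluster.Cleanup()

	// ... exercise sys/raw and sys/generate-recovery-token here ...
}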