Add logic to skip initialization in some cases and some invalidation logic

This commit is contained in:
Jeff Mitchell 2017-05-05 15:01:52 -04:00
parent 5f99def25b
commit 6f6f242061
12 changed files with 75 additions and 29 deletions

View file

@ -72,6 +72,8 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
AuthRenew: b.pathLoginRenew, AuthRenew: b.pathLoginRenew,
Init: b.initialize, Init: b.initialize,
Invalidate: b.invalidate,
} }
b.view = conf.StorageView b.view = conf.StorageView
@ -91,6 +93,7 @@ type backend struct {
func (b *backend) initialize() error { func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{ salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA1Hash, HashFunc: salt.SHA1Hash,
Location: salt.DefaultLocation,
}) })
if err != nil { if err != nil {
return err return err
@ -174,6 +177,14 @@ func (b *backend) upgradeToSalted(view logical.Storage) error {
return nil return nil
} }
func (b *backend) invalidate(key string) {
switch key {
case salt.DefaultLocation:
// reread the salt
b.initialize()
}
}
const backendHelp = ` const backendHelp = `
The App ID credential provider is used to perform authentication from The App ID credential provider is used to perform authentication from
within applications or machine by pairing together two hard-to-guess within applications or machine by pairing together two hard-to-guess

View file

@ -14,7 +14,8 @@ type backend struct {
// The salt value to be used by the information to be accessed only // The salt value to be used by the information to be accessed only
// by this backend. // by this backend.
salt *salt.Salt salt *salt.Salt
saltMutex sync.RWMutex
// The view to use when creating the salt // The view to use when creating the salt
view logical.Storage view logical.Storage
@ -92,14 +93,18 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathTidySecretID(b), pathTidySecretID(b),
}, },
), ),
Init: b.initialize, Init: b.initialize,
Invalidate: b.invalidate,
} }
return b, nil return b, nil
} }
func (b *backend) initialize() error { func (b *backend) initialize() error {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
salt, err := salt.NewSalt(b.view, &salt.Config{ salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash, HashFunc: salt.SHA256Hash,
Location: salt.DefaultLocation,
}) })
if err != nil { if err != nil {
return err return err
@ -108,6 +113,14 @@ func (b *backend) initialize() error {
return nil return nil
} }
func (b *backend) invalidate(key string) {
switch key {
case salt.DefaultLocation:
// reread the salt
b.initialize()
}
}
// periodicFunc of the backend will be invoked once a minute by the RollbackManager. // periodicFunc of the backend will be invoked once a minute by the RollbackManager.
// RoleRole backend utilizes this function to delete expired SecretID entries. // RoleRole backend utilizes this function to delete expired SecretID entries.
// This could mean that the SecretID may live in the backend upto 1 min after its // This could mean that the SecretID may live in the backend upto 1 min after its

View file

@ -1939,7 +1939,9 @@ func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
b.saltMutex.RLock()
entryIndex := "role_id/" + b.salt.SaltID(roleID) entryIndex := "role_id/" + b.salt.SaltID(roleID)
b.saltMutex.RUnlock()
entry, err := logical.StorageEntryJSON(entryIndex, roleIDEntry) entry, err := logical.StorageEntryJSON(entryIndex, roleIDEntry)
if err != nil { if err != nil {
@ -1963,7 +1965,9 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
var result roleIDStorageEntry var result roleIDStorageEntry
b.saltMutex.RLock()
entryIndex := "role_id/" + b.salt.SaltID(roleID) entryIndex := "role_id/" + b.salt.SaltID(roleID)
b.saltMutex.RUnlock()
if entry, err := s.Get(entryIndex); err != nil { if entry, err := s.Get(entryIndex); err != nil {
return nil, err return nil, err
@ -1987,7 +1991,9 @@ func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
b.saltMutex.RLock()
entryIndex := "role_id/" + b.salt.SaltID(roleID) entryIndex := "role_id/" + b.salt.SaltID(roleID)
b.saltMutex.RUnlock()
return s.Delete(entryIndex) return s.Delete(entryIndex)
} }

View file

@ -469,7 +469,9 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
var result secretIDAccessorStorageEntry var result secretIDAccessorStorageEntry
// Create index entry, mapping the accessor to the token ID // Create index entry, mapping the accessor to the token ID
b.saltMutex.RLock()
entryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor) entryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor)
b.saltMutex.RUnlock()
accessorLock := b.secretIDAccessorLock(secretIDAccessor) accessorLock := b.secretIDAccessorLock(secretIDAccessor)
accessorLock.RLock() accessorLock.RLock()
@ -498,7 +500,9 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
entry.SecretIDAccessor = accessorUUID entry.SecretIDAccessor = accessorUUID
// Create index entry, mapping the accessor to the token ID // Create index entry, mapping the accessor to the token ID
b.saltMutex.RLock()
entryIndex := "accessor/" + b.salt.SaltID(entry.SecretIDAccessor) entryIndex := "accessor/" + b.salt.SaltID(entry.SecretIDAccessor)
b.saltMutex.RUnlock()
accessorLock := b.secretIDAccessorLock(accessorUUID) accessorLock := b.secretIDAccessorLock(accessorUUID)
accessorLock.Lock() accessorLock.Lock()
@ -517,7 +521,9 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
// deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID. // deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID.
func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error { func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error {
b.saltMutex.RLock()
accessorEntryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor) accessorEntryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor)
b.saltMutex.RUnlock()
accessorLock := b.secretIDAccessorLock(secretIDAccessor) accessorLock := b.secretIDAccessorLock(secretIDAccessor)
accessorLock.Lock() accessorLock.Lock()

View file

@ -6,7 +6,6 @@ import (
"github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -21,7 +20,6 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
type backend struct { type backend struct {
*framework.Backend *framework.Backend
Salt *salt.Salt
// Used during initialization to set the salt // Used during initialization to set the salt
view logical.Storage view logical.Storage
@ -105,24 +103,11 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
}, },
Invalidate: b.invalidate, Invalidate: b.invalidate,
Init: b.initialize,
} }
return b, nil return b, nil
} }
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.Salt = salt
return nil
}
// periodicFunc performs the tasks that the backend wishes to do periodically. // periodicFunc performs the tasks that the backend wishes to do periodically.
// Currently this will be triggered once in a minute by the RollbackManager. // Currently this will be triggered once in a minute by the RollbackManager.
// //

View file

@ -2,6 +2,7 @@ package ssh
import ( import (
"strings" "strings"
"sync"
"github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
@ -10,8 +11,9 @@ import (
type backend struct { type backend struct {
*framework.Backend *framework.Backend
view logical.Storage view logical.Storage
salt *salt.Salt salt *salt.Salt
saltMutex sync.RWMutex
} }
func Factory(conf *logical.BackendConfig) (logical.Backend, error) { func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
@ -57,14 +59,19 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
secretOTP(&b), secretOTP(&b),
}, },
Init: b.Initialize, Init: b.initialize,
Invalidate: b.invalidate,
} }
return &b, nil return &b, nil
} }
func (b *backend) Initialize() error { func (b *backend) initialize() error {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
salt, err := salt.NewSalt(b.view, &salt.Config{ salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash, HashFunc: salt.SHA256Hash,
Location: salt.DefaultLocation,
}) })
if err != nil { if err != nil {
return err return err
@ -73,6 +80,14 @@ func (b *backend) Initialize() error {
return nil return nil
} }
func (b *backend) invalidate(key string) {
switch key {
case salt.DefaultLocation:
// reread the salt
b.initialize()
}
}
const backendHelp = ` const backendHelp = `
The SSH backend generates credentials allowing clients to establish SSH The SSH backend generates credentials allowing clients to establish SSH
connections to remote hosts. connections to remote hosts.

View file

@ -207,6 +207,8 @@ func (b *backend) GenerateSaltedOTP() (string, string, error) {
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
b.saltMutex.RLock()
defer b.saltMutex.RUnlock()
return str, b.salt.SaltID(str), nil return str, b.salt.SaltID(str), nil
} }

View file

@ -57,7 +57,9 @@ func (b *backend) pathVerifyWrite(req *logical.Request, d *framework.FieldData)
// Create the salt of OTP because entry would have been create with the // Create the salt of OTP because entry would have been create with the
// salt and not directly of the OTP. Salt will yield the same value which // salt and not directly of the OTP. Salt will yield the same value which
// because the seed is the same, the backend salt. // because the seed is the same, the backend salt.
b.saltMutex.RLock()
otpSalted := b.salt.SaltID(otp) otpSalted := b.salt.SaltID(otp)
b.saltMutex.RUnlock()
// Return nil if there is no entry found for the OTP // Return nil if there is no entry found for the OTP
otpEntry, err := b.getOTP(req.Storage, otpSalted) otpEntry, err := b.getOTP(req.Storage, otpSalted)

View file

@ -33,6 +33,8 @@ func (b *backend) secretOTPRevoke(req *logical.Request, d *framework.FieldData)
return nil, fmt.Errorf("secret is missing internal data") return nil, fmt.Errorf("secret is missing internal data")
} }
b.saltMutex.RLock()
defer b.saltMutex.RUnlock()
err := req.Storage.Delete("otp/" + b.salt.SaltID(otp)) err := req.Storage.Delete("otp/" + b.salt.SaltID(otp))
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -42,7 +42,7 @@ var (
) )
// enableCredential is used to enable a new credential backend // enableCredential is used to enable a new credential backend
func (c *Core) enableCredential(entry *MountEntry) error { func (c *Core) enableCredential(entry *MountEntry, skipInitialization bool) error {
// Ensure we end the path in a slash // Ensure we end the path in a slash
if !strings.HasSuffix(entry.Path, "/") { if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/" entry.Path += "/"
@ -99,8 +99,10 @@ func (c *Core) enableCredential(entry *MountEntry) error {
return fmt.Errorf("nil backend returned from %q factory", entry.Type) return fmt.Errorf("nil backend returned from %q factory", entry.Type)
} }
if err := backend.Initialize(); err != nil { if !skipInitialization {
return err if err := backend.Initialize(); err != nil {
return err
}
} }
// Update the auth table // Update the auth table

View file

@ -1210,7 +1210,7 @@ func (b *SystemBackend) handleMount(
} }
// Attempt mount // Attempt mount
if err := b.Core.mount(me); err != nil { if err := b.Core.mount(me, false); err != nil {
b.Backend.Logger().Error("sys: mount failed", "path", me.Path, "error", err) b.Backend.Logger().Error("sys: mount failed", "path", me.Path, "error", err)
return handleError(err) return handleError(err)
} }
@ -1642,7 +1642,7 @@ func (b *SystemBackend) handleEnableAuth(
} }
// Attempt enabling // Attempt enabling
if err := b.Core.enableCredential(me); err != nil { if err := b.Core.enableCredential(me, false); err != nil {
b.Backend.Logger().Error("sys: enable auth mount failed", "path", me.Path, "error", err) b.Backend.Logger().Error("sys: enable auth mount failed", "path", me.Path, "error", err)
return handleError(err) return handleError(err)
} }

View file

@ -169,7 +169,7 @@ func (e *MountEntry) Clone() *MountEntry {
} }
// Mount is used to mount a new backend to the mount table. // Mount is used to mount a new backend to the mount table.
func (c *Core) mount(entry *MountEntry) error { func (c *Core) mount(entry *MountEntry, skipInitialization bool) error {
// Ensure we end the path in a slash // Ensure we end the path in a slash
if !strings.HasSuffix(entry.Path, "/") { if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/" entry.Path += "/"
@ -219,8 +219,10 @@ func (c *Core) mount(entry *MountEntry) error {
// Call initialize; this takes care of init tasks that must be run after // Call initialize; this takes care of init tasks that must be run after
// the ignore paths are collected // the ignore paths are collected
if err := backend.Initialize(); err != nil { if !skipInitialization {
return err if err := backend.Initialize(); err != nil {
return err
}
} }
newTable := c.mounts.shallowClone() newTable := c.mounts.shallowClone()