Add logic to skip initialization in some cases and some invalidation logic
This commit is contained in:
parent
5f99def25b
commit
6f6f242061
|
@ -72,6 +72,8 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||||
AuthRenew: b.pathLoginRenew,
|
AuthRenew: b.pathLoginRenew,
|
||||||
|
|
||||||
Init: b.initialize,
|
Init: b.initialize,
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
|
|
||||||
b.view = conf.StorageView
|
b.view = conf.StorageView
|
||||||
|
@ -91,6 +93,7 @@ type backend struct {
|
||||||
func (b *backend) initialize() error {
|
func (b *backend) initialize() error {
|
||||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
HashFunc: salt.SHA1Hash,
|
HashFunc: salt.SHA1Hash,
|
||||||
|
Location: salt.DefaultLocation,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -174,6 +177,14 @@ func (b *backend) upgradeToSalted(view logical.Storage) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case salt.DefaultLocation:
|
||||||
|
// reread the salt
|
||||||
|
b.initialize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
The App ID credential provider is used to perform authentication from
|
The App ID credential provider is used to perform authentication from
|
||||||
within applications or machine by pairing together two hard-to-guess
|
within applications or machine by pairing together two hard-to-guess
|
||||||
|
|
|
@ -14,7 +14,8 @@ type backend struct {
|
||||||
|
|
||||||
// The salt value to be used by the information to be accessed only
|
// The salt value to be used by the information to be accessed only
|
||||||
// by this backend.
|
// by this backend.
|
||||||
salt *salt.Salt
|
salt *salt.Salt
|
||||||
|
saltMutex sync.RWMutex
|
||||||
|
|
||||||
// The view to use when creating the salt
|
// The view to use when creating the salt
|
||||||
view logical.Storage
|
view logical.Storage
|
||||||
|
@ -92,14 +93,18 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||||
pathTidySecretID(b),
|
pathTidySecretID(b),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Init: b.initialize,
|
Init: b.initialize,
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *backend) initialize() error {
|
func (b *backend) initialize() error {
|
||||||
|
b.saltMutex.Lock()
|
||||||
|
defer b.saltMutex.Unlock()
|
||||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
HashFunc: salt.SHA256Hash,
|
HashFunc: salt.SHA256Hash,
|
||||||
|
Location: salt.DefaultLocation,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -108,6 +113,14 @@ func (b *backend) initialize() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case salt.DefaultLocation:
|
||||||
|
// reread the salt
|
||||||
|
b.initialize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// periodicFunc of the backend will be invoked once a minute by the RollbackManager.
|
// periodicFunc of the backend will be invoked once a minute by the RollbackManager.
|
||||||
// RoleRole backend utilizes this function to delete expired SecretID entries.
|
// RoleRole backend utilizes this function to delete expired SecretID entries.
|
||||||
// This could mean that the SecretID may live in the backend upto 1 min after its
|
// This could mean that the SecretID may live in the backend upto 1 min after its
|
||||||
|
|
|
@ -1939,7 +1939,9 @@ func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
b.saltMutex.RLock()
|
||||||
entryIndex := "role_id/" + b.salt.SaltID(roleID)
|
entryIndex := "role_id/" + b.salt.SaltID(roleID)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
entry, err := logical.StorageEntryJSON(entryIndex, roleIDEntry)
|
entry, err := logical.StorageEntryJSON(entryIndex, roleIDEntry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1963,7 +1965,9 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
|
||||||
|
|
||||||
var result roleIDStorageEntry
|
var result roleIDStorageEntry
|
||||||
|
|
||||||
|
b.saltMutex.RLock()
|
||||||
entryIndex := "role_id/" + b.salt.SaltID(roleID)
|
entryIndex := "role_id/" + b.salt.SaltID(roleID)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
if entry, err := s.Get(entryIndex); err != nil {
|
if entry, err := s.Get(entryIndex); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1987,7 +1991,9 @@ func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
b.saltMutex.RLock()
|
||||||
entryIndex := "role_id/" + b.salt.SaltID(roleID)
|
entryIndex := "role_id/" + b.salt.SaltID(roleID)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
return s.Delete(entryIndex)
|
return s.Delete(entryIndex)
|
||||||
}
|
}
|
||||||
|
|
|
@ -469,7 +469,9 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
|
||||||
var result secretIDAccessorStorageEntry
|
var result secretIDAccessorStorageEntry
|
||||||
|
|
||||||
// Create index entry, mapping the accessor to the token ID
|
// Create index entry, mapping the accessor to the token ID
|
||||||
|
b.saltMutex.RLock()
|
||||||
entryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor)
|
entryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
accessorLock := b.secretIDAccessorLock(secretIDAccessor)
|
accessorLock := b.secretIDAccessorLock(secretIDAccessor)
|
||||||
accessorLock.RLock()
|
accessorLock.RLock()
|
||||||
|
@ -498,7 +500,9 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
|
||||||
entry.SecretIDAccessor = accessorUUID
|
entry.SecretIDAccessor = accessorUUID
|
||||||
|
|
||||||
// Create index entry, mapping the accessor to the token ID
|
// Create index entry, mapping the accessor to the token ID
|
||||||
|
b.saltMutex.RLock()
|
||||||
entryIndex := "accessor/" + b.salt.SaltID(entry.SecretIDAccessor)
|
entryIndex := "accessor/" + b.salt.SaltID(entry.SecretIDAccessor)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
accessorLock := b.secretIDAccessorLock(accessorUUID)
|
accessorLock := b.secretIDAccessorLock(accessorUUID)
|
||||||
accessorLock.Lock()
|
accessorLock.Lock()
|
||||||
|
@ -517,7 +521,9 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
|
||||||
|
|
||||||
// deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID.
|
// deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID.
|
||||||
func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error {
|
func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error {
|
||||||
|
b.saltMutex.RLock()
|
||||||
accessorEntryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor)
|
accessorEntryIndex := "accessor/" + b.salt.SaltID(secretIDAccessor)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
accessorLock := b.secretIDAccessorLock(secretIDAccessor)
|
accessorLock := b.secretIDAccessorLock(secretIDAccessor)
|
||||||
accessorLock.Lock()
|
accessorLock.Lock()
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/service/ec2"
|
"github.com/aws/aws-sdk-go/service/ec2"
|
||||||
"github.com/aws/aws-sdk-go/service/iam"
|
"github.com/aws/aws-sdk-go/service/iam"
|
||||||
"github.com/hashicorp/vault/helper/salt"
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
)
|
)
|
||||||
|
@ -21,7 +20,6 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||||
|
|
||||||
type backend struct {
|
type backend struct {
|
||||||
*framework.Backend
|
*framework.Backend
|
||||||
Salt *salt.Salt
|
|
||||||
|
|
||||||
// Used during initialization to set the salt
|
// Used during initialization to set the salt
|
||||||
view logical.Storage
|
view logical.Storage
|
||||||
|
@ -105,24 +103,11 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||||
},
|
},
|
||||||
|
|
||||||
Invalidate: b.invalidate,
|
Invalidate: b.invalidate,
|
||||||
|
|
||||||
Init: b.initialize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *backend) initialize() error {
|
|
||||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
|
||||||
HashFunc: salt.SHA256Hash,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
b.Salt = salt
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// periodicFunc performs the tasks that the backend wishes to do periodically.
|
// periodicFunc performs the tasks that the backend wishes to do periodically.
|
||||||
// Currently this will be triggered once in a minute by the RollbackManager.
|
// Currently this will be triggered once in a minute by the RollbackManager.
|
||||||
//
|
//
|
||||||
|
|
|
@ -2,6 +2,7 @@ package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/salt"
|
"github.com/hashicorp/vault/helper/salt"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
|
@ -10,8 +11,9 @@ import (
|
||||||
|
|
||||||
type backend struct {
|
type backend struct {
|
||||||
*framework.Backend
|
*framework.Backend
|
||||||
view logical.Storage
|
view logical.Storage
|
||||||
salt *salt.Salt
|
salt *salt.Salt
|
||||||
|
saltMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||||
|
@ -57,14 +59,19 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||||
secretOTP(&b),
|
secretOTP(&b),
|
||||||
},
|
},
|
||||||
|
|
||||||
Init: b.Initialize,
|
Init: b.initialize,
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
return &b, nil
|
return &b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *backend) Initialize() error {
|
func (b *backend) initialize() error {
|
||||||
|
b.saltMutex.Lock()
|
||||||
|
defer b.saltMutex.Unlock()
|
||||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
HashFunc: salt.SHA256Hash,
|
HashFunc: salt.SHA256Hash,
|
||||||
|
Location: salt.DefaultLocation,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -73,6 +80,14 @@ func (b *backend) Initialize() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case salt.DefaultLocation:
|
||||||
|
// reread the salt
|
||||||
|
b.initialize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
The SSH backend generates credentials allowing clients to establish SSH
|
The SSH backend generates credentials allowing clients to establish SSH
|
||||||
connections to remote hosts.
|
connections to remote hosts.
|
||||||
|
|
|
@ -207,6 +207,8 @@ func (b *backend) GenerateSaltedOTP() (string, string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
b.saltMutex.RLock()
|
||||||
|
defer b.saltMutex.RUnlock()
|
||||||
return str, b.salt.SaltID(str), nil
|
return str, b.salt.SaltID(str), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,9 @@ func (b *backend) pathVerifyWrite(req *logical.Request, d *framework.FieldData)
|
||||||
// Create the salt of OTP because entry would have been create with the
|
// Create the salt of OTP because entry would have been create with the
|
||||||
// salt and not directly of the OTP. Salt will yield the same value which
|
// salt and not directly of the OTP. Salt will yield the same value which
|
||||||
// because the seed is the same, the backend salt.
|
// because the seed is the same, the backend salt.
|
||||||
|
b.saltMutex.RLock()
|
||||||
otpSalted := b.salt.SaltID(otp)
|
otpSalted := b.salt.SaltID(otp)
|
||||||
|
b.saltMutex.RUnlock()
|
||||||
|
|
||||||
// Return nil if there is no entry found for the OTP
|
// Return nil if there is no entry found for the OTP
|
||||||
otpEntry, err := b.getOTP(req.Storage, otpSalted)
|
otpEntry, err := b.getOTP(req.Storage, otpSalted)
|
||||||
|
|
|
@ -33,6 +33,8 @@ func (b *backend) secretOTPRevoke(req *logical.Request, d *framework.FieldData)
|
||||||
return nil, fmt.Errorf("secret is missing internal data")
|
return nil, fmt.Errorf("secret is missing internal data")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b.saltMutex.RLock()
|
||||||
|
defer b.saltMutex.RUnlock()
|
||||||
err := req.Storage.Delete("otp/" + b.salt.SaltID(otp))
|
err := req.Storage.Delete("otp/" + b.salt.SaltID(otp))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -42,7 +42,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// enableCredential is used to enable a new credential backend
|
// enableCredential is used to enable a new credential backend
|
||||||
func (c *Core) enableCredential(entry *MountEntry) error {
|
func (c *Core) enableCredential(entry *MountEntry, skipInitialization bool) error {
|
||||||
// Ensure we end the path in a slash
|
// Ensure we end the path in a slash
|
||||||
if !strings.HasSuffix(entry.Path, "/") {
|
if !strings.HasSuffix(entry.Path, "/") {
|
||||||
entry.Path += "/"
|
entry.Path += "/"
|
||||||
|
@ -99,8 +99,10 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
||||||
return fmt.Errorf("nil backend returned from %q factory", entry.Type)
|
return fmt.Errorf("nil backend returned from %q factory", entry.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := backend.Initialize(); err != nil {
|
if !skipInitialization {
|
||||||
return err
|
if err := backend.Initialize(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the auth table
|
// Update the auth table
|
||||||
|
|
|
@ -1210,7 +1210,7 @@ func (b *SystemBackend) handleMount(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt mount
|
// Attempt mount
|
||||||
if err := b.Core.mount(me); err != nil {
|
if err := b.Core.mount(me, false); err != nil {
|
||||||
b.Backend.Logger().Error("sys: mount failed", "path", me.Path, "error", err)
|
b.Backend.Logger().Error("sys: mount failed", "path", me.Path, "error", err)
|
||||||
return handleError(err)
|
return handleError(err)
|
||||||
}
|
}
|
||||||
|
@ -1642,7 +1642,7 @@ func (b *SystemBackend) handleEnableAuth(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt enabling
|
// Attempt enabling
|
||||||
if err := b.Core.enableCredential(me); err != nil {
|
if err := b.Core.enableCredential(me, false); err != nil {
|
||||||
b.Backend.Logger().Error("sys: enable auth mount failed", "path", me.Path, "error", err)
|
b.Backend.Logger().Error("sys: enable auth mount failed", "path", me.Path, "error", err)
|
||||||
return handleError(err)
|
return handleError(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -169,7 +169,7 @@ func (e *MountEntry) Clone() *MountEntry {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mount is used to mount a new backend to the mount table.
|
// Mount is used to mount a new backend to the mount table.
|
||||||
func (c *Core) mount(entry *MountEntry) error {
|
func (c *Core) mount(entry *MountEntry, skipInitialization bool) error {
|
||||||
// Ensure we end the path in a slash
|
// Ensure we end the path in a slash
|
||||||
if !strings.HasSuffix(entry.Path, "/") {
|
if !strings.HasSuffix(entry.Path, "/") {
|
||||||
entry.Path += "/"
|
entry.Path += "/"
|
||||||
|
@ -219,8 +219,10 @@ func (c *Core) mount(entry *MountEntry) error {
|
||||||
|
|
||||||
// Call initialize; this takes care of init tasks that must be run after
|
// Call initialize; this takes care of init tasks that must be run after
|
||||||
// the ignore paths are collected
|
// the ignore paths are collected
|
||||||
if err := backend.Initialize(); err != nil {
|
if !skipInitialization {
|
||||||
return err
|
if err := backend.Initialize(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newTable := c.mounts.shallowClone()
|
newTable := c.mounts.shallowClone()
|
||||||
|
|
Loading…
Reference in a new issue