parent
494b4c844b
commit
f37b6492d1
|
@ -1,7 +1,6 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -17,6 +16,10 @@ const (
|
|||
// can only be viewed or modified after an unseal.
|
||||
coreAuthConfigPath = "core/auth"
|
||||
|
||||
// coreLocalAuthConfigPath is used to store credential configuration for
|
||||
// local (non-replicated) mounts
|
||||
coreLocalAuthConfigPath = "core/local-auth"
|
||||
|
||||
// credentialBarrierPrefix is the prefix to the UUID used in the
|
||||
// barrier view for the credential backends.
|
||||
credentialBarrierPrefix = "auth/"
|
||||
|
@ -71,16 +74,25 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
|||
}
|
||||
|
||||
// Generate a new UUID and view
|
||||
entryUUID, err := uuid.GenerateUUID()
|
||||
if entry.UUID == "" {
|
||||
entryUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry.UUID = entryUUID
|
||||
}
|
||||
|
||||
viewPath := credentialBarrierPrefix + entry.UUID + "/"
|
||||
view := NewBarrierView(c.barrier, viewPath)
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
|
||||
// Create the new backend
|
||||
backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry.UUID = entryUUID
|
||||
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
||||
|
||||
// Create the new backend
|
||||
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
|
||||
if err != nil {
|
||||
if err := backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
|
|||
fullPath := credentialRoutePrefix + path
|
||||
view := c.router.MatchingStorageView(fullPath)
|
||||
if view == nil {
|
||||
return false, fmt.Errorf("no matching backend")
|
||||
return false, fmt.Errorf("no matching backend %s", fullPath)
|
||||
}
|
||||
|
||||
// Mark the entry as tainted
|
||||
|
@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error {
|
|||
// loadCredentials is invoked as part of postUnseal to load the auth table
|
||||
func (c *Core) loadCredentials() error {
|
||||
authTable := &MountTable{}
|
||||
localAuthTable := &MountTable{}
|
||||
|
||||
// Load the existing mount table
|
||||
raw, err := c.barrier.Get(coreAuthConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to read auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to read local auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
|
@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error {
|
|||
}
|
||||
c.auth = authTable
|
||||
}
|
||||
if rawLocal != nil {
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil {
|
||||
c.logger.Error("core: failed to decode local auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...)
|
||||
}
|
||||
|
||||
// Done if we have restored the auth table
|
||||
if c.auth != nil {
|
||||
|
@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error {
|
|||
}
|
||||
}
|
||||
|
||||
nonLocalAuth := &MountTable{
|
||||
Type: credentialTableType,
|
||||
}
|
||||
|
||||
localAuth := &MountTable{
|
||||
Type: credentialTableType,
|
||||
}
|
||||
|
||||
for _, entry := range table.Entries {
|
||||
if entry.Local {
|
||||
localAuth.Entries = append(localAuth.Entries, entry)
|
||||
} else {
|
||||
nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the table
|
||||
raw, err := json.Marshal(table)
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to encode auth table", "error", err)
|
||||
c.logger.Error("core: failed to encode and/or compress auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an entry
|
||||
entry := &Entry{
|
||||
Key: coreAuthConfigPath,
|
||||
Value: raw,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
// Write to the physical backend
|
||||
|
@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error {
|
|||
c.logger.Error("core: failed to persist auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Repeat with local auth
|
||||
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to encode and/or compress local auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
entry = &Entry{
|
||||
Key: coreLocalAuthConfigPath,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
if err := c.barrier.Put(entry); err != nil {
|
||||
c.logger.Error("core: failed to persist local auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error {
|
|||
}
|
||||
|
||||
// Create a barrier view using the UUID
|
||||
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
||||
viewPath := credentialBarrierPrefix + entry.UUID + "/"
|
||||
view = NewBarrierView(c.barrier, viewPath)
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
|
||||
// Initialize the backend
|
||||
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
|
||||
backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
|
||||
if err := backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mount the backend
|
||||
path := credentialRoutePrefix + entry.Path
|
||||
err = c.router.Mount(backend, path, entry, view)
|
||||
|
|
|
@ -2,8 +2,10 @@ package vault
|
|||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
|
@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that the local table actually gets populated as expected with local
|
||||
// entries, and that upon reading the entries from both are recombined
|
||||
// correctly
|
||||
func TestCore_EnableCredential_Local(t *testing.T) {
|
||||
c, _, _ := TestCoreUnsealed(t)
|
||||
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
return &NoopBackend{}, nil
|
||||
}
|
||||
|
||||
c.auth = &MountTable{
|
||||
Type: credentialTableType,
|
||||
Entries: []*MountEntry{
|
||||
&MountEntry{
|
||||
Table: credentialTableType,
|
||||
Path: "noop/",
|
||||
Type: "noop",
|
||||
UUID: "abcd",
|
||||
},
|
||||
&MountEntry{
|
||||
Table: credentialTableType,
|
||||
Path: "noop2/",
|
||||
Type: "noop",
|
||||
UUID: "bcde",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Both should set up successfully
|
||||
err := c.setupCredentials()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawLocal == nil {
|
||||
t.Fatal("expected non-nil local credential")
|
||||
}
|
||||
localCredentialTable := &MountTable{}
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(localCredentialTable.Entries) > 0 {
|
||||
t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable)
|
||||
}
|
||||
|
||||
c.auth.Entries[1].Local = true
|
||||
if err := c.persistAuth(c.auth); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawLocal == nil {
|
||||
t.Fatal("expected non-nil local credential")
|
||||
}
|
||||
localCredentialTable = &MountTable{}
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(localCredentialTable.Entries) != 1 {
|
||||
t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable)
|
||||
}
|
||||
|
||||
oldCredential := c.auth
|
||||
if err := c.loadCredentials(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(oldCredential, c.auth) {
|
||||
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth)
|
||||
}
|
||||
|
||||
if len(c.auth.Entries) != 2 {
|
||||
t.Fatalf("expected two credential entries, got %#v", localCredentialTable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCore_EnableCredential_twice_409(t *testing.T) {
|
||||
c, _, _ := TestCoreUnsealed(t)
|
||||
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
|
@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) {
|
|||
}
|
||||
|
||||
existed, err := c.disableCredential("foo")
|
||||
if existed || err.Error() != "no matching backend" {
|
||||
if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) {
|
||||
t.Fatalf("existed: %v; err: %v", existed, err)
|
||||
}
|
||||
|
||||
|
|
|
@ -86,6 +86,11 @@ type SecurityBarrier interface {
|
|||
// VerifyMaster is used to check if the given key matches the master key
|
||||
VerifyMaster(key []byte) error
|
||||
|
||||
// SetMasterKey is used to directly set a new master key. This is used in
|
||||
// repliated scenarios due to the chicken and egg problem of reloading the
|
||||
// keyring from disk before we have the master key to decrypt it.
|
||||
SetMasterKey(key []byte) error
|
||||
|
||||
// ReloadKeyring is used to re-read the underlying keyring.
|
||||
// This is used for HA deployments to ensure the latest keyring
|
||||
// is present in the leader.
|
||||
|
@ -119,8 +124,14 @@ type SecurityBarrier interface {
|
|||
// Rekey is used to change the master key used to protect the keyring
|
||||
Rekey([]byte) error
|
||||
|
||||
// For replication we must send over the keyring, so this must be available
|
||||
Keyring() (*Keyring, error)
|
||||
|
||||
// SecurityBarrier must provide the storage APIs
|
||||
BarrierStorage
|
||||
|
||||
// SecurityBarrier must provide the encryption APIs
|
||||
BarrierEncryptor
|
||||
}
|
||||
|
||||
// BarrierStorage is the storage only interface required for a Barrier.
|
||||
|
@ -139,6 +150,14 @@ type BarrierStorage interface {
|
|||
List(prefix string) ([]string, error)
|
||||
}
|
||||
|
||||
// BarrierEncryptor is the in memory only interface that does not actually
|
||||
// use the underlying barrier. It is used for lower level modules like the
|
||||
// Write-Ahead-Log and Merkle index to allow them to use the barrier.
|
||||
type BarrierEncryptor interface {
|
||||
Encrypt(key string, plaintext []byte) ([]byte, error)
|
||||
Decrypt(key string, ciphertext []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// Entry is used to represent data stored by the security barrier
|
||||
type Entry struct {
|
||||
Key string
|
||||
|
|
|
@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) {
|
|||
func (b *AESGCMBarrier) Rekey(key []byte) error {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
if b.sealed {
|
||||
return ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Verify the key size
|
||||
min, max := b.KeyLength()
|
||||
if len(key) < min || len(key) > max {
|
||||
return fmt.Errorf("Key size must be %d or %d", min, max)
|
||||
newKeyring, err := b.updateMasterKeyCommon(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add a new encryption key
|
||||
newKeyring := b.keyring.SetMasterKey(key)
|
||||
|
||||
// Persist the new keyring
|
||||
if err := b.persistKeyring(newKeyring); err != nil {
|
||||
return err
|
||||
|
@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetMasterKey updates the keyring's in-memory master key but does not persist
|
||||
// anything to storage
|
||||
func (b *AESGCMBarrier) SetMasterKey(key []byte) error {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
newKeyring, err := b.updateMasterKeyCommon(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Swap the keyrings
|
||||
oldKeyring := b.keyring
|
||||
b.keyring = newKeyring
|
||||
oldKeyring.Zeroize(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performs common tasks related to updating the master key; note that the lock
|
||||
// must be held before calling this function
|
||||
func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Verify the key size
|
||||
min, max := b.KeyLength()
|
||||
if len(key) < min || len(key) > max {
|
||||
return nil, fmt.Errorf("Key size must be %d or %d", min, max)
|
||||
}
|
||||
|
||||
return b.keyring.SetMasterKey(key), nil
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry
|
||||
func (b *AESGCMBarrier) Put(entry *Entry) error {
|
||||
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
|
||||
|
@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro
|
|||
return nil, fmt.Errorf("version bytes mis-match")
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface
|
||||
func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
term := b.keyring.ActiveTerm()
|
||||
primary, err := b.aeadForTerm(term)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := b.encrypt(key, term, primary, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface
|
||||
func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Decrypt the ciphertext
|
||||
plain, err := b.decryptKeyring(key, ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %v", err)
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Keyring() (*Keyring, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
return b.keyring.Clone(), nil
|
||||
}
|
||||
|
|
|
@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) {
|
|||
t.Fatalf("key length protection failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_BarrierEncryptor(t *testing.T) {
|
||||
inm := physical.NewInmem(logger)
|
||||
b, err := NewAESGCMBarrier(inm)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Initialize and unseal
|
||||
key, _ := b.GenerateKey()
|
||||
b.Initialize(key)
|
||||
b.Unseal(key)
|
||||
|
||||
cipher, err := b.Encrypt("foo", []byte("quick brown fox"))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
plain, err := b.Decrypt("foo", cipher)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if string(plain) != "quick brown fox" {
|
||||
t.Fatalf("bad: %s", plain)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
|
|||
|
||||
// logical.Storage impl.
|
||||
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
if err := v.sanityCheck(entry.Key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expandedKey := v.expandKey(entry.Key)
|
||||
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
|
||||
nested := &Entry{
|
||||
Key: v.expandKey(entry.Key),
|
||||
Key: expandedKey,
|
||||
Value: entry.Value,
|
||||
}
|
||||
return v.barrier.Put(nested)
|
||||
|
@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
|||
|
||||
// logical.Storage impl.
|
||||
func (v *BarrierView) Delete(key string) error {
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
if err := v.sanityCheck(key); err != nil {
|
||||
return err
|
||||
}
|
||||
return v.barrier.Delete(v.expandKey(key))
|
||||
|
||||
expandedKey := v.expandKey(key)
|
||||
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
|
||||
|
||||
return v.barrier.Delete(expandedKey)
|
||||
}
|
||||
|
||||
// SubView constructs a nested sub-view using the given prefix
|
||||
|
|
|
@ -1,27 +1,19 @@
|
|||
package vault
|
||||
|
||||
import "sort"
|
||||
import (
|
||||
"sort"
|
||||
|
||||
// Struct to identify user input errors.
|
||||
// This is helpful in responding the appropriate status codes to clients
|
||||
// from the HTTP endpoints.
|
||||
type StatusBadRequest struct {
|
||||
Err string
|
||||
}
|
||||
|
||||
// Implementing error interface
|
||||
func (s *StatusBadRequest) Error() string {
|
||||
return s.Err
|
||||
}
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// Capabilities is used to fetch the capabilities of the given token on the given path
|
||||
func (c *Core) Capabilities(token, path string) ([]string, error) {
|
||||
if path == "" {
|
||||
return nil, &StatusBadRequest{Err: "missing path"}
|
||||
return nil, &logical.StatusBadRequest{Err: "missing path"}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return nil, &StatusBadRequest{Err: "missing token"}
|
||||
return nil, &logical.StatusBadRequest{Err: "missing token"}
|
||||
}
|
||||
|
||||
te, err := c.tokenStore.Lookup(token)
|
||||
|
@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
if te == nil {
|
||||
return nil, &StatusBadRequest{Err: "invalid token"}
|
||||
return nil, &logical.StatusBadRequest{Err: "invalid token"}
|
||||
}
|
||||
|
||||
if te.Policies == nil {
|
||||
|
|
|
@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
|
|||
return nil, fmt.Errorf("error initializing seal: %v", err)
|
||||
}
|
||||
|
||||
err = c.seal.SetBarrierConfig(barrierConfig)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to save barrier configuration", "error", err)
|
||||
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
|
||||
}
|
||||
|
||||
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
|
||||
if err != nil {
|
||||
c.logger.Error("core: error generating shares", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we are storing shares, pop them out of the returned results and push
|
||||
// them through the seal
|
||||
if barrierConfig.StoredShares > 0 {
|
||||
var keysToStore [][]byte
|
||||
for i := 0; i < barrierConfig.StoredShares; i++ {
|
||||
keysToStore = append(keysToStore, barrierUnsealKeys[0])
|
||||
barrierUnsealKeys = barrierUnsealKeys[1:]
|
||||
}
|
||||
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
|
||||
c.logger.Error("core: failed to store keys", "error", err)
|
||||
return nil, fmt.Errorf("failed to store keys: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
results := &InitResult{
|
||||
SecretShares: barrierUnsealKeys,
|
||||
}
|
||||
|
||||
// Initialize the barrier
|
||||
if err := c.barrier.Initialize(barrierKey); err != nil {
|
||||
c.logger.Error("core: failed to initialize barrier", "error", err)
|
||||
|
@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
|
|||
|
||||
// Ensure the barrier is re-sealed
|
||||
defer func() {
|
||||
// Defers are LIFO so we need to run this here too to ensure the stop
|
||||
// happens before sealing. preSeal also stops, so we just make the
|
||||
// stopping safe against multiple calls.
|
||||
if err := c.barrier.Seal(); err != nil {
|
||||
c.logger.Error("core: failed to seal barrier", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = c.seal.SetBarrierConfig(barrierConfig)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to save barrier configuration", "error", err)
|
||||
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
|
||||
}
|
||||
|
||||
// If we are storing shares, pop them out of the returned results and push
|
||||
// them through the seal
|
||||
if barrierConfig.StoredShares > 0 {
|
||||
var keysToStore [][]byte
|
||||
for i := 0; i < barrierConfig.StoredShares; i++ {
|
||||
keysToStore = append(keysToStore, barrierUnsealKeys[0])
|
||||
barrierUnsealKeys = barrierUnsealKeys[1:]
|
||||
}
|
||||
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
|
||||
c.logger.Error("core: failed to store keys", "error", err)
|
||||
return nil, fmt.Errorf("failed to store keys: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
results := &InitResult{
|
||||
SecretShares: barrierUnsealKeys,
|
||||
}
|
||||
|
||||
// Perform initial setup
|
||||
if err := c.setupCluster(); err != nil {
|
||||
c.logger.Error("core: cluster setup failed during init", "error", err)
|
||||
|
|
|
@ -237,7 +237,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err = TestCoreUnseal(c, result.SecretShares[i])
|
||||
_, err = TestCoreUnseal(c, TestKeyCopy(result.SecretShares[i]))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -270,7 +270,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str
|
|||
// Provide the parts master
|
||||
oldResult := result
|
||||
for i := 0; i < 3; i++ {
|
||||
result, err = c.RekeyUpdate(oldResult.SecretShares[i], rkconf.Nonce, recovery)
|
||||
result, err = c.RekeyUpdate(TestKeyCopy(oldResult.SecretShares[i]), rkconf.Nonce, recovery)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -27,6 +27,9 @@ func (c *Core) startForwarding() error {
|
|||
// Clean up in case we have transitioned from a client to a server
|
||||
c.clearForwardingClients()
|
||||
|
||||
// Resolve locally to avoid races
|
||||
ha := c.ha != nil
|
||||
|
||||
// Get our base handler (for our RPC server) and our wrapped handler (for
|
||||
// straight HTTP/2 forwarding)
|
||||
baseHandler, wrappedHandler := c.clusterHandlerSetupFunc()
|
||||
|
@ -43,10 +46,13 @@ func (c *Core) startForwarding() error {
|
|||
|
||||
// Create our RPC server and register the request handler server
|
||||
c.rpcServer = grpc.NewServer()
|
||||
RegisterRequestForwardingServer(c.rpcServer, &forwardedRequestRPCServer{
|
||||
core: c,
|
||||
handler: baseHandler,
|
||||
})
|
||||
|
||||
if ha {
|
||||
RegisterRequestForwardingServer(c.rpcServer, &forwardedRequestRPCServer{
|
||||
core: c,
|
||||
handler: baseHandler,
|
||||
})
|
||||
}
|
||||
|
||||
// Create the HTTP/2 server that will be shared by both RPC and regular
|
||||
// duties. Doing it this way instead of listening via the server and gRPC
|
||||
|
@ -82,6 +88,7 @@ func (c *Core) startForwarding() error {
|
|||
|
||||
// Wrap the listener with TLS
|
||||
tlsLn := tls.NewListener(tcpLn, tlsConfig)
|
||||
defer tlsLn.Close()
|
||||
|
||||
if c.logger.IsInfo() {
|
||||
c.logger.Info("core/startClusterListener: serving cluster requests", "cluster_listen_address", tlsLn.Addr())
|
||||
|
@ -89,7 +96,6 @@ func (c *Core) startForwarding() error {
|
|||
|
||||
for {
|
||||
if atomic.LoadUint32(&shutdown) > 0 {
|
||||
tlsLn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -100,10 +106,11 @@ func (c *Core) startForwarding() error {
|
|||
|
||||
// Accept the connection
|
||||
conn, err := tlsLn.Accept()
|
||||
if conn != nil {
|
||||
// Always defer although it may be closed ahead of time
|
||||
defer conn.Close()
|
||||
}
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -123,19 +130,29 @@ func (c *Core) startForwarding() error {
|
|||
|
||||
switch tlsConn.ConnectionState().NegotiatedProtocol {
|
||||
case "h2":
|
||||
if !ha {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Debug("core/startClusterListener/Accept: got h2 connection")
|
||||
go fws.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Handler: wrappedHandler,
|
||||
})
|
||||
|
||||
case "req_fw_sb-act_v1":
|
||||
if !ha {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Debug("core/startClusterListener/Accept: got req_fw_sb-act_v1 connection")
|
||||
go fws.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Handler: c.rpcServer,
|
||||
})
|
||||
|
||||
default:
|
||||
c.logger.Debug("core/startClusterListener/Accept: unknown negotiated protocol")
|
||||
c.logger.Debug("core: unknown negotiated protocol on cluster port")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
@ -154,8 +171,9 @@ func (c *Core) startForwarding() error {
|
|||
<-c.clusterListenerShutdownCh
|
||||
|
||||
// Stop the RPC server
|
||||
c.logger.Info("core: shutting down forwarding rpc listeners")
|
||||
c.rpcServer.Stop()
|
||||
c.logger.Info("core/startClusterListener: shutting down listeners")
|
||||
c.logger.Info("core: forwarding rpc listeners stopped")
|
||||
|
||||
// Set the shutdown flag. This will cause the listeners to shut down
|
||||
// within the deadline in clusterListenerAcceptDeadline
|
||||
|
@ -163,7 +181,7 @@ func (c *Core) startForwarding() error {
|
|||
|
||||
// Wait for them all to shut down
|
||||
shutdownWg.Wait()
|
||||
c.logger.Info("core/startClusterListener: listeners successfully shut down")
|
||||
c.logger.Info("core: rpc listeners successfully shut down")
|
||||
|
||||
// Tell the main thread that shutdown is done.
|
||||
c.clusterListenerShutdownSuccessCh <- struct{}{}
|
||||
|
@ -223,6 +241,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
|
|||
// It's not really insecure, but we have to dial manually to get the
|
||||
// ALPN header right. It's just "insecure" because GRPC isn't managing
|
||||
// the TLS state.
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
c.rpcClientConnCancelFunc = cancelFunc
|
||||
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "")), grpc.WithInsecure())
|
||||
|
|
|
@ -184,7 +184,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
}
|
||||
|
||||
// Route the request
|
||||
resp, err := c.router.Route(req)
|
||||
resp, routeErr := c.router.Route(req)
|
||||
if resp != nil {
|
||||
// If wrapping is used, use the shortest between the request and response
|
||||
var wrapTTL time.Duration
|
||||
|
@ -306,8 +306,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
}
|
||||
|
||||
// Return the response and error
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
if routeErr != nil {
|
||||
retErr = multierror.Append(retErr, routeErr)
|
||||
}
|
||||
return resp, auth, retErr
|
||||
}
|
||||
|
@ -331,7 +331,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
}
|
||||
|
||||
// Route the request
|
||||
resp, err := c.router.Route(req)
|
||||
resp, routeErr := c.router.Route(req)
|
||||
if resp != nil {
|
||||
// If wrapping is used, use the shortest between the request and response
|
||||
var wrapTTL time.Duration
|
||||
|
@ -446,5 +446,5 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
req.DisplayName = auth.DisplayName
|
||||
}
|
||||
|
||||
return resp, auth, err
|
||||
return resp, auth, routeErr
|
||||
}
|
||||
|
|
|
@ -243,8 +243,12 @@ func testTokenStore(t testing.TB, c *Core) *TokenStore {
|
|||
me.UUID = meUUID
|
||||
|
||||
view := NewBarrierView(c.barrier, credentialBarrierPrefix+me.UUID+"/")
|
||||
sysView := c.mountEntrySysView(me)
|
||||
|
||||
tokenstore, _ := c.newCredentialBackend("token", c.mountEntrySysView(me), view, nil)
|
||||
tokenstore, _ := c.newCredentialBackend("token", sysView, view, nil)
|
||||
if err := tokenstore.Initialize(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ts := tokenstore.(*TokenStore)
|
||||
|
||||
router := NewRouter()
|
||||
|
|
|
@ -109,19 +109,10 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
|
|||
t.policyLookupFunc = c.policyStore.GetPolicy
|
||||
}
|
||||
|
||||
// Setup the salt
|
||||
salt, err := salt.NewSalt(view, &salt.Config{
|
||||
HashFunc: salt.SHA1Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.salt = salt
|
||||
|
||||
t.tokenLocks = map[string]*sync.RWMutex{}
|
||||
|
||||
// Create 256 locks
|
||||
if err = locksutil.CreateLocks(t.tokenLocks, 256); err != nil {
|
||||
if err := locksutil.CreateLocks(t.tokenLocks, 256); err != nil {
|
||||
return nil, fmt.Errorf("failed to create locks: %v", err)
|
||||
}
|
||||
|
||||
|
@ -136,6 +127,15 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
|
|||
"revoke-orphan/*",
|
||||
"accessors*",
|
||||
},
|
||||
|
||||
// Most token store items are local since tokens are local, but a
|
||||
// notable exception is roles
|
||||
LocalStorage: []string{
|
||||
lookupPrefix,
|
||||
accessorPrefix,
|
||||
parentPrefix,
|
||||
"salt",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
|
@ -467,6 +467,8 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
|
|||
HelpDescription: strings.TrimSpace(tokenTidyDesc),
|
||||
},
|
||||
},
|
||||
|
||||
Init: t.Initialize,
|
||||
}
|
||||
|
||||
t.Backend.Setup(config)
|
||||
|
@ -474,6 +476,19 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func (ts *TokenStore) Initialize() error {
|
||||
// Setup the salt
|
||||
salt, err := salt.NewSalt(ts.view, &salt.Config{
|
||||
HashFunc: salt.SHA1Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts.salt = salt
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenEntry is used to represent a given token
|
||||
type TokenEntry struct {
|
||||
// ID of this entry, generally a random UUID
|
||||
|
@ -1085,7 +1100,7 @@ func (ts *TokenStore) lookupBySaltedAccessor(saltedAccessor string) (accessorEnt
|
|||
return aEntry, fmt.Errorf("failed to read index using accessor: %s", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return aEntry, &StatusBadRequest{Err: "invalid accessor"}
|
||||
return aEntry, &logical.StatusBadRequest{Err: "invalid accessor"}
|
||||
}
|
||||
|
||||
err = jsonutil.DecodeJSON(entry.Value, &aEntry)
|
||||
|
@ -1225,7 +1240,7 @@ func (ts *TokenStore) handleUpdateLookupAccessor(req *logical.Request, data *fra
|
|||
if accessor == "" {
|
||||
accessor = data.Get("urlaccessor").(string)
|
||||
if accessor == "" {
|
||||
return nil, &StatusBadRequest{Err: "missing accessor"}
|
||||
return nil, &logical.StatusBadRequest{Err: "missing accessor"}
|
||||
}
|
||||
urlaccessor = true
|
||||
}
|
||||
|
@ -1279,7 +1294,7 @@ func (ts *TokenStore) handleUpdateRevokeAccessor(req *logical.Request, data *fra
|
|||
if accessor == "" {
|
||||
accessor = data.Get("urlaccessor").(string)
|
||||
if accessor == "" {
|
||||
return nil, &StatusBadRequest{Err: "missing accessor"}
|
||||
return nil, &logical.StatusBadRequest{Err: "missing accessor"}
|
||||
}
|
||||
urlaccessor = true
|
||||
}
|
||||
|
|
|
@ -437,6 +437,9 @@ func TestTokenStore_CreateLookup(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if err := ts2.Initialize(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Should still match
|
||||
out, err = ts2.Lookup(ent.ID)
|
||||
|
@ -476,6 +479,9 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if err := ts2.Initialize(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Should still match
|
||||
out, err = ts2.Lookup(ent.ID)
|
||||
|
|
Loading…
Reference in New Issue