More rep porting (#2391)

* More rep porting

* Add a bit more porting
This commit is contained in:
Jeff Mitchell 2017-02-16 23:09:39 -05:00 committed by GitHub
parent 494b4c844b
commit f37b6492d1
14 changed files with 418 additions and 103 deletions

View File

@ -1,7 +1,6 @@
package vault
import (
"encoding/json"
"errors"
"fmt"
"strings"
@ -17,6 +16,10 @@ const (
// can only be viewed or modified after an unseal.
coreAuthConfigPath = "core/auth"
// coreLocalAuthConfigPath is used to store credential configuration for
// local (non-replicated) mounts
coreLocalAuthConfigPath = "core/local-auth"
// credentialBarrierPrefix is the prefix to the UUID used in the
// barrier view for the credential backends.
credentialBarrierPrefix = "auth/"
@ -71,16 +74,25 @@ func (c *Core) enableCredential(entry *MountEntry) error {
}
// Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID()
if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
viewPath := credentialBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
// Create the new backend
backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil)
if err != nil {
return err
}
entry.UUID = entryUUID
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Create the new backend
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
if err := backend.Initialize(); err != nil {
return err
}
@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
fullPath := credentialRoutePrefix + path
view := c.router.MatchingStorageView(fullPath)
if view == nil {
return false, fmt.Errorf("no matching backend")
return false, fmt.Errorf("no matching backend %s", fullPath)
}
// Mark the entry as tainted
@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error {
// loadCredentials is invoked as part of postUnseal to load the auth table
func (c *Core) loadCredentials() error {
authTable := &MountTable{}
localAuthTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreAuthConfigPath)
if err != nil {
c.logger.Error("core: failed to read auth table", "error", err)
return errLoadAuthFailed
}
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
c.logger.Error("core: failed to read local auth table", "error", err)
return errLoadAuthFailed
}
c.authLock.Lock()
defer c.authLock.Unlock()
@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error {
}
c.auth = authTable
}
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil {
c.logger.Error("core: failed to decode local auth table", "error", err)
return errLoadAuthFailed
}
c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...)
}
// Done if we have restored the auth table
if c.auth != nil {
@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error {
}
}
nonLocalAuth := &MountTable{
Type: credentialTableType,
}
localAuth := &MountTable{
Type: credentialTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localAuth.Entries = append(localAuth.Entries, entry)
} else {
nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry)
}
}
// Marshal the table
raw, err := json.Marshal(table)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
if err != nil {
c.logger.Error("core: failed to encode auth table", "error", err)
c.logger.Error("core: failed to encode and/or compress auth table", "error", err)
return err
}
// Create an entry
entry := &Entry{
Key: coreAuthConfigPath,
Value: raw,
Value: compressedBytes,
}
// Write to the physical backend
@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error {
c.logger.Error("core: failed to persist auth table", "error", err)
return err
}
// Repeat with local auth
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local auth table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalAuthConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local auth table", "error", err)
return err
}
return nil
}
@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error {
}
// Create a barrier view using the UUID
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
viewPath := credentialBarrierPrefix + entry.UUID + "/"
view = NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
// Initialize the backend
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil)
if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
return errLoadAuthFailed
}
if err := backend.Initialize(); err != nil {
return err
}
// Mount the backend
path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry, view)

View File

@ -2,8 +2,10 @@ package vault
import (
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
)
@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) {
}
}
// Test that the local table actually gets populated as expected with local
// entries, and that upon reading the entries from both are recombined
// correctly
func TestCore_EnableCredential_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
return &NoopBackend{}, nil
}
c.auth = &MountTable{
Type: credentialTableType,
Entries: []*MountEntry{
&MountEntry{
Table: credentialTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: credentialTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupCredentials()
if err != nil {
t.Fatal(err)
}
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local credential")
}
localCredentialTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
t.Fatal(err)
}
if len(localCredentialTable.Entries) > 0 {
t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable)
}
c.auth.Entries[1].Local = true
if err := c.persistAuth(c.auth); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local credential")
}
localCredentialTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
t.Fatal(err)
}
if len(localCredentialTable.Entries) != 1 {
t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable)
}
oldCredential := c.auth
if err := c.loadCredentials(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(oldCredential, c.auth) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth)
}
if len(c.auth.Entries) != 2 {
t.Fatalf("expected two credential entries, got %#v", localCredentialTable)
}
}
func TestCore_EnableCredential_twice_409(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) {
}
existed, err := c.disableCredential("foo")
if existed || err.Error() != "no matching backend" {
if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) {
t.Fatalf("existed: %v; err: %v", existed, err)
}

View File

@ -86,6 +86,11 @@ type SecurityBarrier interface {
// VerifyMaster is used to check if the given key matches the master key
VerifyMaster(key []byte) error
// SetMasterKey is used to directly set a new master key. This is used in
// repliated scenarios due to the chicken and egg problem of reloading the
// keyring from disk before we have the master key to decrypt it.
SetMasterKey(key []byte) error
// ReloadKeyring is used to re-read the underlying keyring.
// This is used for HA deployments to ensure the latest keyring
// is present in the leader.
@ -119,8 +124,14 @@ type SecurityBarrier interface {
// Rekey is used to change the master key used to protect the keyring
Rekey([]byte) error
// For replication we must send over the keyring, so this must be available
Keyring() (*Keyring, error)
// SecurityBarrier must provide the storage APIs
BarrierStorage
// SecurityBarrier must provide the encryption APIs
BarrierEncryptor
}
// BarrierStorage is the storage only interface required for a Barrier.
@ -139,6 +150,14 @@ type BarrierStorage interface {
List(prefix string) ([]string, error)
}
// BarrierEncryptor is the in memory only interface that does not actually
// use the underlying barrier. It is used for lower level modules like the
// Write-Ahead-Log and Merkle index to allow them to use the barrier.
type BarrierEncryptor interface {
Encrypt(key string, plaintext []byte) ([]byte, error)
Decrypt(key string, ciphertext []byte) ([]byte, error)
}
// Entry is used to represent data stored by the security barrier
type Entry struct {
Key string

View File

@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) {
func (b *AESGCMBarrier) Rekey(key []byte) error {
b.l.Lock()
defer b.l.Unlock()
if b.sealed {
return ErrBarrierSealed
}
// Verify the key size
min, max := b.KeyLength()
if len(key) < min || len(key) > max {
return fmt.Errorf("Key size must be %d or %d", min, max)
newKeyring, err := b.updateMasterKeyCommon(key)
if err != nil {
return err
}
// Add a new encryption key
newKeyring := b.keyring.SetMasterKey(key)
// Persist the new keyring
if err := b.persistKeyring(newKeyring); err != nil {
return err
@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error {
return nil
}
// SetMasterKey updates the keyring's in-memory master key but does not persist
// anything to storage
func (b *AESGCMBarrier) SetMasterKey(key []byte) error {
b.l.Lock()
defer b.l.Unlock()
newKeyring, err := b.updateMasterKeyCommon(key)
if err != nil {
return err
}
// Swap the keyrings
oldKeyring := b.keyring
b.keyring = newKeyring
oldKeyring.Zeroize(false)
return nil
}
// Performs common tasks related to updating the master key; note that the lock
// must be held before calling this function
func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
if b.sealed {
return nil, ErrBarrierSealed
}
// Verify the key size
min, max := b.KeyLength()
if len(key) < min || len(key) > max {
return nil, fmt.Errorf("Key size must be %d or %d", min, max)
}
return b.keyring.SetMasterKey(key), nil
}
// Put is used to insert or update an entry
func (b *AESGCMBarrier) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro
return nil, fmt.Errorf("version bytes mis-match")
}
}
// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface
func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
term := b.keyring.ActiveTerm()
primary, err := b.aeadForTerm(term)
if err != nil {
return nil, err
}
ciphertext := b.encrypt(key, term, primary, plaintext)
return ciphertext, nil
}
// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface
func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
// Decrypt the ciphertext
plain, err := b.decryptKeyring(key, ciphertext)
if err != nil {
return nil, fmt.Errorf("decryption failed: %v", err)
}
return plain, nil
}
func (b *AESGCMBarrier) Keyring() (*Keyring, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
return b.keyring.Clone(), nil
}

View File

@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) {
t.Fatalf("key length protection failed")
}
}
func TestEncrypt_BarrierEncryptor(t *testing.T) {
inm := physical.NewInmem(logger)
b, err := NewAESGCMBarrier(inm)
if err != nil {
t.Fatalf("err: %v", err)
}
// Initialize and unseal
key, _ := b.GenerateKey()
b.Initialize(key)
b.Unseal(key)
cipher, err := b.Encrypt("foo", []byte("quick brown fox"))
if err != nil {
t.Fatalf("err: %v", err)
}
plain, err := b.Decrypt("foo", cipher)
if err != nil {
t.Fatalf("err: %v", err)
}
if string(plain) != "quick brown fox" {
t.Fatalf("bad: %s", plain)
}
}

View File

@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
// logical.Storage impl.
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(entry.Key); err != nil {
return err
}
expandedKey := v.expandKey(entry.Key)
if v.readonly {
return logical.ErrReadOnly
}
nested := &Entry{
Key: v.expandKey(entry.Key),
Key: expandedKey,
Value: entry.Value,
}
return v.barrier.Put(nested)
@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
// logical.Storage impl.
func (v *BarrierView) Delete(key string) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(key); err != nil {
return err
}
return v.barrier.Delete(v.expandKey(key))
expandedKey := v.expandKey(key)
if v.readonly {
return logical.ErrReadOnly
}
return v.barrier.Delete(expandedKey)
}
// SubView constructs a nested sub-view using the given prefix

View File

@ -1,27 +1,19 @@
package vault
import "sort"
import (
"sort"
// Struct to identify user input errors.
// This is helpful in responding the appropriate status codes to clients
// from the HTTP endpoints.
type StatusBadRequest struct {
Err string
}
// Implementing error interface
func (s *StatusBadRequest) Error() string {
return s.Err
}
"github.com/hashicorp/vault/logical"
)
// Capabilities is used to fetch the capabilities of the given token on the given path
func (c *Core) Capabilities(token, path string) ([]string, error) {
if path == "" {
return nil, &StatusBadRequest{Err: "missing path"}
return nil, &logical.StatusBadRequest{Err: "missing path"}
}
if token == "" {
return nil, &StatusBadRequest{Err: "missing token"}
return nil, &logical.StatusBadRequest{Err: "missing token"}
}
te, err := c.tokenStore.Lookup(token)
@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) {
return nil, err
}
if te == nil {
return nil, &StatusBadRequest{Err: "invalid token"}
return nil, &logical.StatusBadRequest{Err: "invalid token"}
}
if te.Policies == nil {

View File

@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
return nil, fmt.Errorf("error initializing seal: %v", err)
}
err = c.seal.SetBarrierConfig(barrierConfig)
if err != nil {
c.logger.Error("core: failed to save barrier configuration", "error", err)
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
}
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
if err != nil {
c.logger.Error("core: error generating shares", "error", err)
return nil, err
}
// If we are storing shares, pop them out of the returned results and push
// them through the seal
if barrierConfig.StoredShares > 0 {
var keysToStore [][]byte
for i := 0; i < barrierConfig.StoredShares; i++ {
keysToStore = append(keysToStore, barrierUnsealKeys[0])
barrierUnsealKeys = barrierUnsealKeys[1:]
}
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
c.logger.Error("core: failed to store keys", "error", err)
return nil, fmt.Errorf("failed to store keys: %v", err)
}
}
results := &InitResult{
SecretShares: barrierUnsealKeys,
}
// Initialize the barrier
if err := c.barrier.Initialize(barrierKey); err != nil {
c.logger.Error("core: failed to initialize barrier", "error", err)
@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
// Ensure the barrier is re-sealed
defer func() {
// Defers are LIFO so we need to run this here too to ensure the stop
// happens before sealing. preSeal also stops, so we just make the
// stopping safe against multiple calls.
if err := c.barrier.Seal(); err != nil {
c.logger.Error("core: failed to seal barrier", "error", err)
}
}()
err = c.seal.SetBarrierConfig(barrierConfig)
if err != nil {
c.logger.Error("core: failed to save barrier configuration", "error", err)
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
}
// If we are storing shares, pop them out of the returned results and push
// them through the seal
if barrierConfig.StoredShares > 0 {
var keysToStore [][]byte
for i := 0; i < barrierConfig.StoredShares; i++ {
keysToStore = append(keysToStore, barrierUnsealKeys[0])
barrierUnsealKeys = barrierUnsealKeys[1:]
}
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
c.logger.Error("core: failed to store keys", "error", err)
return nil, fmt.Errorf("failed to store keys: %v", err)
}
}
results := &InitResult{
SecretShares: barrierUnsealKeys,
}
// Perform initial setup
if err := c.setupCluster(); err != nil {
c.logger.Error("core: cluster setup failed during init", "error", err)

View File

@ -237,7 +237,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str
t.Fatalf("err: %v", err)
}
for i := 0; i < 3; i++ {
_, err = TestCoreUnseal(c, result.SecretShares[i])
_, err = TestCoreUnseal(c, TestKeyCopy(result.SecretShares[i]))
if err != nil {
t.Fatalf("err: %v", err)
}
@ -270,7 +270,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str
// Provide the parts master
oldResult := result
for i := 0; i < 3; i++ {
result, err = c.RekeyUpdate(oldResult.SecretShares[i], rkconf.Nonce, recovery)
result, err = c.RekeyUpdate(TestKeyCopy(oldResult.SecretShares[i]), rkconf.Nonce, recovery)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -27,6 +27,9 @@ func (c *Core) startForwarding() error {
// Clean up in case we have transitioned from a client to a server
c.clearForwardingClients()
// Resolve locally to avoid races
ha := c.ha != nil
// Get our base handler (for our RPC server) and our wrapped handler (for
// straight HTTP/2 forwarding)
baseHandler, wrappedHandler := c.clusterHandlerSetupFunc()
@ -43,10 +46,13 @@ func (c *Core) startForwarding() error {
// Create our RPC server and register the request handler server
c.rpcServer = grpc.NewServer()
RegisterRequestForwardingServer(c.rpcServer, &forwardedRequestRPCServer{
core: c,
handler: baseHandler,
})
if ha {
RegisterRequestForwardingServer(c.rpcServer, &forwardedRequestRPCServer{
core: c,
handler: baseHandler,
})
}
// Create the HTTP/2 server that will be shared by both RPC and regular
// duties. Doing it this way instead of listening via the server and gRPC
@ -82,6 +88,7 @@ func (c *Core) startForwarding() error {
// Wrap the listener with TLS
tlsLn := tls.NewListener(tcpLn, tlsConfig)
defer tlsLn.Close()
if c.logger.IsInfo() {
c.logger.Info("core/startClusterListener: serving cluster requests", "cluster_listen_address", tlsLn.Addr())
@ -89,7 +96,6 @@ func (c *Core) startForwarding() error {
for {
if atomic.LoadUint32(&shutdown) > 0 {
tlsLn.Close()
return
}
@ -100,10 +106,11 @@ func (c *Core) startForwarding() error {
// Accept the connection
conn, err := tlsLn.Accept()
if conn != nil {
// Always defer although it may be closed ahead of time
defer conn.Close()
}
if err != nil {
if conn != nil {
conn.Close()
}
continue
}
@ -123,19 +130,29 @@ func (c *Core) startForwarding() error {
switch tlsConn.ConnectionState().NegotiatedProtocol {
case "h2":
if !ha {
conn.Close()
continue
}
c.logger.Debug("core/startClusterListener/Accept: got h2 connection")
go fws.ServeConn(conn, &http2.ServeConnOpts{
Handler: wrappedHandler,
})
case "req_fw_sb-act_v1":
if !ha {
conn.Close()
continue
}
c.logger.Debug("core/startClusterListener/Accept: got req_fw_sb-act_v1 connection")
go fws.ServeConn(conn, &http2.ServeConnOpts{
Handler: c.rpcServer,
})
default:
c.logger.Debug("core/startClusterListener/Accept: unknown negotiated protocol")
c.logger.Debug("core: unknown negotiated protocol on cluster port")
conn.Close()
continue
}
@ -154,8 +171,9 @@ func (c *Core) startForwarding() error {
<-c.clusterListenerShutdownCh
// Stop the RPC server
c.logger.Info("core: shutting down forwarding rpc listeners")
c.rpcServer.Stop()
c.logger.Info("core/startClusterListener: shutting down listeners")
c.logger.Info("core: forwarding rpc listeners stopped")
// Set the shutdown flag. This will cause the listeners to shut down
// within the deadline in clusterListenerAcceptDeadline
@ -163,7 +181,7 @@ func (c *Core) startForwarding() error {
// Wait for them all to shut down
shutdownWg.Wait()
c.logger.Info("core/startClusterListener: listeners successfully shut down")
c.logger.Info("core: rpc listeners successfully shut down")
// Tell the main thread that shutdown is done.
c.clusterListenerShutdownSuccessCh <- struct{}{}
@ -223,6 +241,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
// It's not really insecure, but we have to dial manually to get the
// ALPN header right. It's just "insecure" because GRPC isn't managing
// the TLS state.
ctx, cancelFunc := context.WithCancel(context.Background())
c.rpcClientConnCancelFunc = cancelFunc
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "")), grpc.WithInsecure())

View File

@ -184,7 +184,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
}
// Route the request
resp, err := c.router.Route(req)
resp, routeErr := c.router.Route(req)
if resp != nil {
// If wrapping is used, use the shortest between the request and response
var wrapTTL time.Duration
@ -306,8 +306,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
}
// Return the response and error
if err != nil {
retErr = multierror.Append(retErr, err)
if routeErr != nil {
retErr = multierror.Append(retErr, routeErr)
}
return resp, auth, retErr
}
@ -331,7 +331,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
}
// Route the request
resp, err := c.router.Route(req)
resp, routeErr := c.router.Route(req)
if resp != nil {
// If wrapping is used, use the shortest between the request and response
var wrapTTL time.Duration
@ -446,5 +446,5 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
req.DisplayName = auth.DisplayName
}
return resp, auth, err
return resp, auth, routeErr
}

View File

@ -243,8 +243,12 @@ func testTokenStore(t testing.TB, c *Core) *TokenStore {
me.UUID = meUUID
view := NewBarrierView(c.barrier, credentialBarrierPrefix+me.UUID+"/")
sysView := c.mountEntrySysView(me)
tokenstore, _ := c.newCredentialBackend("token", c.mountEntrySysView(me), view, nil)
tokenstore, _ := c.newCredentialBackend("token", sysView, view, nil)
if err := tokenstore.Initialize(); err != nil {
panic(err)
}
ts := tokenstore.(*TokenStore)
router := NewRouter()

View File

@ -109,19 +109,10 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
t.policyLookupFunc = c.policyStore.GetPolicy
}
// Setup the salt
salt, err := salt.NewSalt(view, &salt.Config{
HashFunc: salt.SHA1Hash,
})
if err != nil {
return nil, err
}
t.salt = salt
t.tokenLocks = map[string]*sync.RWMutex{}
// Create 256 locks
if err = locksutil.CreateLocks(t.tokenLocks, 256); err != nil {
if err := locksutil.CreateLocks(t.tokenLocks, 256); err != nil {
return nil, fmt.Errorf("failed to create locks: %v", err)
}
@ -136,6 +127,15 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
"revoke-orphan/*",
"accessors*",
},
// Most token store items are local since tokens are local, but a
// notable exception is roles
LocalStorage: []string{
lookupPrefix,
accessorPrefix,
parentPrefix,
"salt",
},
},
Paths: []*framework.Path{
@ -467,6 +467,8 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
HelpDescription: strings.TrimSpace(tokenTidyDesc),
},
},
Init: t.Initialize,
}
t.Backend.Setup(config)
@ -474,6 +476,19 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
return t, nil
}
func (ts *TokenStore) Initialize() error {
// Setup the salt
salt, err := salt.NewSalt(ts.view, &salt.Config{
HashFunc: salt.SHA1Hash,
})
if err != nil {
return err
}
ts.salt = salt
return nil
}
// TokenEntry is used to represent a given token
type TokenEntry struct {
// ID of this entry, generally a random UUID
@ -1085,7 +1100,7 @@ func (ts *TokenStore) lookupBySaltedAccessor(saltedAccessor string) (accessorEnt
return aEntry, fmt.Errorf("failed to read index using accessor: %s", err)
}
if entry == nil {
return aEntry, &StatusBadRequest{Err: "invalid accessor"}
return aEntry, &logical.StatusBadRequest{Err: "invalid accessor"}
}
err = jsonutil.DecodeJSON(entry.Value, &aEntry)
@ -1225,7 +1240,7 @@ func (ts *TokenStore) handleUpdateLookupAccessor(req *logical.Request, data *fra
if accessor == "" {
accessor = data.Get("urlaccessor").(string)
if accessor == "" {
return nil, &StatusBadRequest{Err: "missing accessor"}
return nil, &logical.StatusBadRequest{Err: "missing accessor"}
}
urlaccessor = true
}
@ -1279,7 +1294,7 @@ func (ts *TokenStore) handleUpdateRevokeAccessor(req *logical.Request, data *fra
if accessor == "" {
accessor = data.Get("urlaccessor").(string)
if accessor == "" {
return nil, &StatusBadRequest{Err: "missing accessor"}
return nil, &logical.StatusBadRequest{Err: "missing accessor"}
}
urlaccessor = true
}

View File

@ -437,6 +437,9 @@ func TestTokenStore_CreateLookup(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
if err := ts2.Initialize(); err != nil {
t.Fatalf("err: %v", err)
}
// Should still match
out, err = ts2.Lookup(ent.ID)
@ -476,6 +479,9 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
if err := ts2.Initialize(); err != nil {
t.Fatalf("err: %v", err)
}
// Should still match
out, err = ts2.Lookup(ent.ID)