More porting from rep (#2389)

* More porting from rep

* Address feedback
This commit is contained in:
Jeff Mitchell 2017-02-16 20:13:19 -05:00 committed by GitHub
parent c81582fea0
commit 494b4c844b
8 changed files with 533 additions and 95 deletions

View File

@ -2,7 +2,6 @@ package vault
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"strings"
@ -26,6 +25,10 @@ const (
// can only be viewed or modified after an unseal.
coreAuditConfigPath = "core/audit"
// coreLocalAuditConfigPath is used to store audit information for local
// (non-replicated) mounts
coreLocalAuditConfigPath = "core/local-audit"
// auditBarrierPrefix is the prefix to the UUID used in the
// barrier view for the audit backends.
auditBarrierPrefix = "audit/"
@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
}
// Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
entry.UUID = entryUUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
// Lookup the new backend
backend, err := c.newAuditBackend(entry, view, entry.Options)
@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {
c.removeAuditReloadFunc(entry)
// When unmounting all entries the JSON code will load back up from storage
// as a nil slice, which kills tests...just set it nil explicitly
if len(newTable.Entries) == 0 {
newTable.Entries = nil
}
// Update the audit table
if err := c.persistAudit(newTable); err != nil {
return true, errors.New("failed to update audit table")
@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) {
if c.logger.IsInfo() {
c.logger.Info("core: disabled audit backend", "path", path)
}
return true, nil
}
// loadAudits is invoked as part of postUnseal to load the audit table
func (c *Core) loadAudits() error {
auditTable := &MountTable{}
localAuditTable := &MountTable{}
// Load the existing audit table
raw, err := c.barrier.Get(coreAuditConfigPath)
@ -144,6 +158,11 @@ func (c *Core) loadAudits() error {
c.logger.Error("core: failed to read audit table", "error", err)
return errLoadAuditFailed
}
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
c.logger.Error("core: failed to read local audit table", "error", err)
return errLoadAuditFailed
}
c.auditLock.Lock()
defer c.auditLock.Unlock()
@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
}
c.audit = auditTable
}
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
c.logger.Error("core: failed to decode local audit table", "error", err)
return errLoadAuditFailed
}
c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
}
// Done if we have restored the audit table
if c.audit != nil {
@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error {
}
}
nonLocalAudit := &MountTable{
Type: auditTableType,
}
localAudit := &MountTable{
Type: auditTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localAudit.Entries = append(localAudit.Entries, entry)
} else {
nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
}
}
// Marshal the table
raw, err := json.Marshal(table)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode audit table", "error", err)
c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
return err
}
// Create an entry
entry := &Entry{
Key: coreAuditConfigPath,
Value: raw,
Value: compressedBytes,
}
// Write to the physical backend
@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error {
c.logger.Error("core: failed to persist audit table", "error", err)
return err
}
// Repeat with local audit
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalAuditConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local audit table", "error", err)
return err
}
return nil
}
@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {
for _, entry := range c.audit.Entries {
// Create a barrier view using the UUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
// Initialize the backend
audit, err := c.newAuditBackend(entry, view, entry.Options)

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
}
}
// Test that the local table actually gets populated as expected with local
// entries, and that upon reading the entries from both are recombined
// correctly
func TestCore_EnableAudit_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return nil, fmt.Errorf("failing enabling")
}
c.audit = &MountTable{
Type: auditTableType,
Entries: []*MountEntry{
&MountEntry{
Table: auditTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: auditTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupAudits()
if err != nil {
t.Fatal(err)
}
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) > 0 {
t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
}
c.audit.Entries[1].Local = true
if err := c.persistAudit(c.audit); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) != 1 {
t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
}
oldAudit := c.audit
if err := c.loadAudits(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(oldAudit, c.audit) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
}
if len(c.audit.Entries) != 2 {
t.Fatalf("expected two audit entries, got %#v", localAuditTable)
}
}
func TestCore_DisableAudit(t *testing.T) {
c, keys, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {
// Verify matching mount tables
if !reflect.DeepEqual(c.audit, c2.audit) {
t.Fatalf("mismatch: %v %v", c.audit, c2.audit)
t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
}
}

View File

@ -43,7 +43,7 @@ var (
// This can be one of a few key types so the different params may or may not be filled
type clusterKeyParams struct {
Type string `json:"type"`
Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
c.logger.Info("core/stopClusterListener: success")
}
// ClusterTLSConfig generates a TLS configuration based on the local cluster
// key and cert.
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
// cluster key and cert.
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
cluster, err := c.Cluster()
if err != nil {
return nil, err
}
if cluster == nil {
return nil, fmt.Errorf("cluster information is nil")
return nil, fmt.Errorf("local cluster information is nil")
}
// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()
if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
return nil, fmt.Errorf("cluster certificate is nil")
forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0
var parsedCert *x509.Certificate
if forwarding {
parsedCert, err = x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
}
// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
}
parsedCert, err := x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
}
nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
c.clusterParamsLock.RLock()
defer c.clusterParamsLock.RUnlock()
// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
tls.Certificate{
if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
return &tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
},
},
RootCAs: c.clusterCertPool,
ServerName: parsedCert.Subject.CommonName,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
MinVersion: tls.VersionTLS12,
}, nil
}
return nil, nil
}
var clientCertificates []tls.Certificate
if forwarding {
clientCertificates = append(clientCertificates, tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
})
}
tlsConfig := &tls.Config{
// We need this here for the client side
Certificates: clientCertificates,
RootCAs: c.clusterCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
GetCertificate: nameLookup,
MinVersion: tls.VersionTLS12,
}
if forwarding {
tlsConfig.ServerName = parsedCert.Subject.CommonName
}
return tlsConfig, nil

View File

@ -1,9 +1,9 @@
package vault
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
@ -57,6 +57,11 @@ const (
// leaderPrefixCleanDelay is how long to wait between deletions
// of orphaned leader keys, to prevent slamming the backend.
leaderPrefixCleanDelay = 200 * time.Millisecond
// coreKeyringCanaryPath is used as a canary to indicate to replicated
// clusters that they need to perform a rekey operation synchronously; this
// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
coreKeyringCanaryPath = "core/canary-keyring"
)
var (
@ -80,6 +85,12 @@ var (
// step down of the active node, to prevent instantly regrabbing the lock.
// It's var not const so that tests can manipulate it.
manualStepDownSleepPeriod = 10 * time.Second
// Functions only in the Enterprise version
enterprisePostUnseal = enterprisePostUnsealImpl
enterprisePreSeal = enterprisePreSealImpl
startReplication = startReplicationImpl
stopReplication = stopReplicationImpl
)
// ReloadFunc are functions that are called when a reload is requested.
@ -126,6 +137,11 @@ type unlockInformation struct {
// interface for API handlers and is responsible for managing the logical and physical
// backends, router, security barrier, and audit trails.
type Core struct {
// N.B.: This is used to populate a dev token down replication, as
// otherwise, after replication is started, a dev would have to go through
// the generate-root process simply to talk to the new follower cluster.
devToken string
// HABackend may be available depending on the physical backend
ha physical.HABackend
@ -261,7 +277,7 @@ type Core struct {
//
// Name
clusterName string
// Used to modify cluster TLS params
// Used to modify cluster parameters
clusterParamsLock sync.RWMutex
// The private key stored in the barrier used for establishing
// mutually-authenticated connections between Vault cluster members
@ -308,6 +324,8 @@ type Core struct {
// CoreConfig is used to parameterize a core
type CoreConfig struct {
DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"`
LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`
CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
@ -383,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) {
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
}
// Setup the core
c := &Core{
redirectAddr: conf.RedirectAddr,
clusterAddr: conf.ClusterAddr,
physical: conf.Physical,
seal: conf.Seal,
router: NewRouter(),
sealed: true,
standby: true,
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the physical backend in a cache layer if enabled and not already wrapped
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if !conf.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
@ -400,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
}
// Construct a new AES-GCM barrier
barrier, err := NewAESGCMBarrier(conf.Physical)
var err error
c.barrier, err = NewAESGCMBarrier(c.physical)
if err != nil {
return nil, fmt.Errorf("barrier setup failed: %v", err)
}
// Setup the core
c := &Core{
redirectAddr: conf.RedirectAddr,
clusterAddr: conf.ClusterAddr,
physical: conf.Physical,
seal: conf.Seal,
barrier: barrier,
router: NewRouter(),
sealed: true,
standby: true,
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the backend in a cache unless disabled
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
c.ha = conf.HAPhysical
}
@ -796,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) {
return true, nil
}
masterKey, err := c.unsealPart(config, key)
if err != nil {
return false, err
}
if masterKey != nil {
return c.unsealInternal(masterKey)
}
return false, nil
}
func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) {
// Check if we already have this piece
if c.unlockInfo != nil {
for _, existing := range c.unlockInfo.Parts {
if bytes.Equal(existing, key) {
return false, nil
if subtle.ConstantTimeCompare(existing, key) == 1 {
return nil, nil
}
}
} else {
uuid, err := uuid.GenerateUUID()
if err != nil {
return false, err
return nil, err
}
c.unlockInfo = &unlockInformation{
Nonce: uuid,
@ -821,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) {
if c.logger.IsDebug() {
c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
}
return false, nil
return nil, nil
}
// Best-effort memzero of unlock parts once we're done with them
defer func() {
for i, _ := range c.unlockInfo.Parts {
memzero(c.unlockInfo.Parts[i])
}
c.unlockInfo = nil
}()
// Recover the master key
var masterKey []byte
var err error
if config.SecretThreshold == 1 {
masterKey = c.unlockInfo.Parts[0]
c.unlockInfo = nil
masterKey = make([]byte, len(c.unlockInfo.Parts[0]))
copy(masterKey, c.unlockInfo.Parts[0])
} else {
masterKey, err = shamir.Combine(c.unlockInfo.Parts)
c.unlockInfo = nil
if err != nil {
return false, fmt.Errorf("failed to compute master key: %v", err)
return nil, fmt.Errorf("failed to compute master key: %v", err)
}
}
defer memzero(masterKey)
return c.unsealInternal(masterKey)
return masterKey, nil
}
// This must be called with the state write lock held
func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
defer memzero(masterKey)
// Attempt to unlock
if err := c.barrier.Unseal(masterKey); err != nil {
return false, err
@ -860,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
c.logger.Warn("core: vault is sealed")
return false, err
}
if err := c.postUnseal(); err != nil {
c.logger.Error("core: post-unseal setup failed", "error", err)
c.barrier.Seal()
c.logger.Warn("core: vault is sealed")
return false, err
}
c.standby = false
} else {
// Go to standby mode, wait until we are active to unseal
@ -1161,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) {
if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge()
}
// HA mode requires us to handle keyring rotation and rekeying
if c.ha != nil {
// We want to reload these from disk so that in case of a rekey we're
@ -1183,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) {
return err
}
}
if err := enterprisePostUnseal(c); err != nil {
return err
}
if err := c.ensureWrappingKey(); err != nil {
return err
}
@ -1244,6 +1290,7 @@ func (c *Core) preSeal() error {
c.metricsCh = nil
}
var result error
if c.ha != nil {
c.stopClusterListener()
}
@ -1266,6 +1313,10 @@ func (c *Core) preSeal() error {
if err := c.unloadMounts(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
}
if err := enterprisePreSeal(c); err != nil {
result = multierror.Append(result, err)
}
// Purge the backend if supported
if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge()
@ -1274,6 +1325,22 @@ func (c *Core) preSeal() error {
return result
}
func enterprisePostUnsealImpl(c *Core) error {
return nil
}
func enterprisePreSealImpl(c *Core) error {
return nil
}
func startReplicationImpl(c *Core) error {
return nil
}
func stopReplicationImpl(c *Core) error {
return nil
}
// runStandby is a long running routine that is used when an HA backend
// is enabled. It waits until we are leader and switches this Vault to
// active.

View File

@ -22,6 +22,23 @@ var (
protectedPaths = []string{
"core",
}
replicationPaths = []*framework.Path{
&framework.Path{
Pattern: "replication/status",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var state consts.ReplicationState
resp := &logical.Response{
Data: map[string]interface{}{
"mode": state.String(),
},
}
return resp, nil
},
},
},
}
)
func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backend, error) {
@ -675,7 +692,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
},
}
b.Backend.Paths = append(b.Backend.Paths, b.replicationPaths()...)
b.Backend.Paths = append(b.Backend.Paths, replicationPaths...)
b.Backend.Invalidate = b.invalidate

View File

@ -19,6 +19,10 @@ const (
// can only be viewed or modified after an unseal.
coreMountConfigPath = "core/mounts"
// coreLocalMountConfigPath is used to store mount configuration for local
// (non-replicated) mounts
coreLocalMountConfigPath = "core/local-mounts"
// backendBarrierPrefix is the prefix to the UUID used in the
// barrier view for the backends.
backendBarrierPrefix = "logical/"
@ -124,6 +128,7 @@ type MountEntry struct {
UUID string `json:"uuid"` // Barrier view UUID
Config MountConfig `json:"config"` // Configuration related to this mount (but not backend-derived)
Options map[string]string `json:"options"` // Backend options
Local bool `json:"local"` // Local mounts are not replicated or affected by replication
Tainted bool `json:"tainted,omitempty"` // Set as a Write-Ahead flag for unmount/remount
}
@ -147,28 +152,29 @@ func (e *MountEntry) Clone() *MountEntry {
UUID: e.UUID,
Config: e.Config,
Options: optClone,
Local: e.Local,
Tainted: e.Tainted,
}
}
// Mount is used to mount a new backend to the mount table.
func (c *Core) mount(me *MountEntry) error {
func (c *Core) mount(entry *MountEntry) error {
// Ensure we end the path in a slash
if !strings.HasSuffix(me.Path, "/") {
me.Path += "/"
if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/"
}
// Prevent protected paths from being mounted
for _, p := range protectedMounts {
if strings.HasPrefix(me.Path, p) {
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
if strings.HasPrefix(entry.Path, p) {
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", entry.Path))
}
}
// Do not allow more than one instance of a singleton mount
for _, p := range singletonMounts {
if me.Type == p {
return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", me.Type))
if entry.Type == p {
return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", entry.Type))
}
}
@ -176,37 +182,47 @@ func (c *Core) mount(me *MountEntry) error {
defer c.mountsLock.Unlock()
// Verify there is no conflicting mount
if match := c.router.MatchingMount(me.Path); match != "" {
if match := c.router.MatchingMount(entry.Path); match != "" {
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}
// Generate a new UUID and view
meUUID, err := uuid.GenerateUUID()
if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
viewPath := backendBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
backend, err := c.newLogicalBackend(entry.Type, sysView, view, nil)
if err != nil {
return err
}
me.UUID = meUUID
view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")
backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil)
if err != nil {
// Call initialize; this takes care of init tasks that must be run after
// the ignore paths are collected
if err := backend.Initialize(); err != nil {
return err
}
newTable := c.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, me)
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistMounts(newTable); err != nil {
c.logger.Error("core: failed to update mount table", "error", err)
return logical.CodedError(500, "failed to update mount table")
}
c.mounts = newTable
if err := c.router.Mount(backend, me.Path, me, view); err != nil {
if err := c.router.Mount(backend, entry.Path, entry, view); err != nil {
return err
}
if c.logger.IsInfo() {
c.logger.Info("core: successful mount", "path", me.Path, "type", me.Type)
c.logger.Info("core: successful mount", "path", entry.Path, "type", entry.Type)
}
return nil
}
@ -291,6 +307,12 @@ func (c *Core) removeMountEntry(path string) error {
newTable := c.mounts.shallowClone()
newTable.remove(path)
// When unmounting all entries the JSON code will load back up from storage
// as a nil slice, which kills tests...just set it nil explicitly
if len(newTable.Entries) == 0 {
newTable.Entries = nil
}
// Update the mount table
if err := c.persistMounts(newTable); err != nil {
c.logger.Error("core: failed to update mount table", "error", err)
@ -405,12 +427,18 @@ func (c *Core) remount(src, dst string) error {
// loadMounts is invoked as part of postUnseal to load the mount table
func (c *Core) loadMounts() error {
mountTable := &MountTable{}
localMountTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreMountConfigPath)
if err != nil {
c.logger.Error("core: failed to read mount table", "error", err)
return errLoadMountsFailed
}
rawLocal, err := c.barrier.Get(coreLocalMountConfigPath)
if err != nil {
c.logger.Error("core: failed to read local mount table", "error", err)
return errLoadMountsFailed
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
@ -425,6 +453,13 @@ func (c *Core) loadMounts() error {
}
c.mounts = mountTable
}
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localMountTable); err != nil {
c.logger.Error("core: failed to decompress and/or decode the local mount table", "error", err)
return err
}
c.mounts.Entries = append(c.mounts.Entries, localMountTable.Entries...)
}
// Ensure that required entries are loaded, or new ones
// added may never get loaded at all. Note that this
@ -492,8 +527,24 @@ func (c *Core) persistMounts(table *MountTable) error {
}
}
nonLocalMounts := &MountTable{
Type: mountTableType,
}
localMounts := &MountTable{
Type: mountTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localMounts.Entries = append(localMounts.Entries, entry)
} else {
nonLocalMounts.Entries = append(nonLocalMounts.Entries, entry)
}
}
// Encode the mount table into JSON and compress it (lzw).
compressedBytes, err := jsonutil.EncodeJSONAndCompress(table, nil)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress the mount table", "error", err)
return err
@ -510,6 +561,24 @@ func (c *Core) persistMounts(table *MountTable) error {
c.logger.Error("core: failed to persist mount table", "error", err)
return err
}
// Repeat with local mounts
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localMounts, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress the local mount table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalMountConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local mount table", "error", err)
return err
}
return nil
}
@ -532,15 +601,19 @@ func (c *Core) setupMounts() error {
// Create a barrier view using the UUID
view = NewBarrierView(c.barrier, barrierPath)
sysView := c.mountEntrySysView(entry)
// Initialize the backend
// Create the new backend
backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil)
if err != nil {
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
return errLoadMountsFailed
}
if err := backend.Initialize(); err != nil {
return err
}
switch entry.Type {
case "system":
c.systemBarrierView = view
@ -616,10 +689,10 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi
// mountEntrySysView creates a logical.SystemView from global and
// mount-specific entries; because this should be called when setting
// up a mountEntry, it doesn't check to ensure that me is not nil
func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView {
func (c *Core) mountEntrySysView(entry *MountEntry) logical.SystemView {
return dynamicSystemView{
core: c,
mountEntry: me,
mountEntry: entry,
}
}

View File

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/compressutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
)
@ -82,6 +83,96 @@ func TestCore_Mount(t *testing.T) {
}
}
// Test that the local table actually gets populated as expected with local
// entries, and that upon reading the entries from both are recombined
// correctly
func TestCore_Mount_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.mounts = &MountTable{
Type: mountTableType,
Entries: []*MountEntry{
&MountEntry{
Table: mountTableType,
Path: "noop/",
Type: "generic",
UUID: "abcd",
},
&MountEntry{
Table: mountTableType,
Path: "noop2/",
Type: "generic",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupMounts()
if err != nil {
t.Fatal(err)
}
if len(c.mounts.Entries) != 2 {
t.Fatalf("expected two entries, got %d", len(c.mounts.Entries))
}
rawLocal, err := c.barrier.Get(coreLocalMountConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local mounts")
}
localMountsTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil {
t.Fatal(err)
}
if len(localMountsTable.Entries) > 0 {
t.Fatalf("expected no entries in local mount table, got %#v", localMountsTable)
}
c.mounts.Entries[1].Local = true
if err := c.persistMounts(c.mounts); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalMountConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local mount")
}
localMountsTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil {
t.Fatal(err)
}
if len(localMountsTable.Entries) != 1 {
t.Fatalf("expected one entry in local mount table, got %#v", localMountsTable)
}
oldMounts := c.mounts
if err := c.loadMounts(); err != nil {
t.Fatal(err)
}
compEntries := c.mounts.Entries[:0]
// Filter out required mounts
for _, v := range c.mounts.Entries {
if v.Type == "generic" {
compEntries = append(compEntries, v)
}
}
c.mounts.Entries = compEntries
if !reflect.DeepEqual(oldMounts, c.mounts) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldMounts, c.mounts)
}
if len(c.mounts.Entries) != 2 {
t.Fatalf("expected two mount entries, got %#v", localMountsTable)
}
}
func TestCore_Unmount(t *testing.T) {
c, keys, _ := TestCoreUnsealed(t)
existed, err := c.unmount("secret")

View File

@ -2,11 +2,13 @@ package vault
import (
"fmt"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
@ -137,7 +139,13 @@ func (c *Core) setupPolicyStore() error {
view := c.systemBarrierView.SubView(policySubPath)
// Create the policy store
c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c})
sysView := &dynamicSystemView{core: c}
c.policyStore = NewPolicyStore(view, sysView)
if sysView.ReplicationState() == consts.ReplicationSecondary {
// Policies will sync from the primary
return nil
}
// Ensure that the default policy exists, and if not, create it
policy, err := c.policyStore.GetPolicy("default")
@ -173,6 +181,16 @@ func (c *Core) teardownPolicyStore() error {
return nil
}
func (ps *PolicyStore) invalidate(name string) {
if ps.lru == nil {
// Nothing to do if the cache is not used
return
}
// This may come with a prefixed "/" due to joining the file path
ps.lru.Remove(strings.TrimPrefix(name, "/"))
}
// SetPolicy is used to create or update the given policy
func (ps *PolicyStore) SetPolicy(p *Policy) error {
defer metrics.MeasureSince([]string{"policy", "set_policy"}, time.Now())