diff --git a/vault/audit.go b/vault/audit.go index 3df4cd96b..b8974963e 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -2,7 +2,6 @@ package vault import ( "crypto/sha256" - "encoding/json" "errors" "fmt" "strings" @@ -26,6 +25,10 @@ const ( // can only be viewed or modified after an unseal. coreAuditConfigPath = "core/audit" + // coreLocalAuditConfigPath is used to store audit information for local + // (non-replicated) mounts + coreLocalAuditConfigPath = "core/local-audit" + // auditBarrierPrefix is the prefix to the UUID used in the // barrier view for the audit backends. auditBarrierPrefix = "audit/" @@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error { } // Generate a new UUID and view - entryUUID, err := uuid.GenerateUUID() - if err != nil { - return err + if entry.UUID == "" { + entryUUID, err := uuid.GenerateUUID() + if err != nil { + return err + } + entry.UUID = entryUUID } - entry.UUID = entryUUID - view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") + viewPath := auditBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) // Lookup the new backend backend, err := c.newAuditBackend(entry, view, entry.Options) @@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) { c.removeAuditReloadFunc(entry) + // When unmounting all entries the JSON code will load back up from storage + // as a nil slice, which kills tests...just set it nil explicitly + if len(newTable.Entries) == 0 { + newTable.Entries = nil + } + // Update the audit table if err := c.persistAudit(newTable); err != nil { return true, errors.New("failed to update audit table") @@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) { if c.logger.IsInfo() { c.logger.Info("core: disabled audit backend", "path", path) } + return true, nil } // loadAudits is invoked as part of postUnseal to load the audit table func (c *Core) loadAudits() error { auditTable := &MountTable{} + localAuditTable := &MountTable{} // Load the existing audit table raw, err := c.barrier.Get(coreAuditConfigPath) @@ -144,6 +158,11 @@ func (c *Core) loadAudits() error { c.logger.Error("core: failed to read audit table", "error", err) return errLoadAuditFailed } + rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath) + if err != nil { + c.logger.Error("core: failed to read local audit table", "error", err) + return errLoadAuditFailed + } c.auditLock.Lock() defer c.auditLock.Unlock() @@ -155,6 +174,13 @@ func (c *Core) loadAudits() error { } c.audit = auditTable } + if rawLocal != nil { + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil { + c.logger.Error("core: failed to decode local audit table", "error", err) + return errLoadAuditFailed + } + c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...) + } // Done if we have restored the audit table if c.audit != nil { @@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error { } } + nonLocalAudit := &MountTable{ + Type: auditTableType, + } + + localAudit := &MountTable{ + Type: auditTableType, + } + + for _, entry := range table.Entries { + if entry.Local { + localAudit.Entries = append(localAudit.Entries, entry) + } else { + nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry) + } + } + // Marshal the table - raw, err := json.Marshal(table) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil) if err != nil { - c.logger.Error("core: failed to encode audit table", "error", err) + c.logger.Error("core: failed to encode and/or compress audit table", "error", err) return err } // Create an entry entry := &Entry{ Key: coreAuditConfigPath, - Value: raw, + Value: compressedBytes, } // Write to the physical backend @@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error { c.logger.Error("core: failed to persist audit table", "error", err) return err } + + // Repeat with local audit + compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil) + if err != nil { + c.logger.Error("core: failed to encode and/or compress local audit table", "error", err) + return err + } + + entry = &Entry{ + Key: coreLocalAuditConfigPath, + Value: compressedBytes, + } + + if err := c.barrier.Put(entry); err != nil { + c.logger.Error("core: failed to persist local audit table", "error", err) + return err + } + return nil } @@ -236,7 +296,8 @@ func (c *Core) setupAudits() error { for _, entry := range c.audit.Entries { // Create a barrier view using the UUID - view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") + viewPath := auditBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) // Initialize the backend audit, err := c.newAuditBackend(entry, view, entry.Options) diff --git a/vault/audit_test.go b/vault/audit_test.go index e1cd51cf9..491be4915 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/logical" log "github.com/mgutz/logxi/v1" @@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) { } } +// Test that the local table actually gets populated as expected with local +// entries, and that upon reading the entries from both are recombined +// correctly +func TestCore_EnableAudit_Local(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil + } + + c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return nil, fmt.Errorf("failing enabling") + } + + c.audit = &MountTable{ + Type: auditTableType, + Entries: []*MountEntry{ + &MountEntry{ + Table: auditTableType, + Path: "noop/", + Type: "noop", + UUID: "abcd", + }, + &MountEntry{ + Table: auditTableType, + Path: "noop2/", + Type: "noop", + UUID: "bcde", + }, + }, + } + + // Both should set up successfully + err := c.setupAudits() + if err != nil { + t.Fatal(err) + } + + rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local audit") + } + localAuditTable := &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil { + t.Fatal(err) + } + if len(localAuditTable.Entries) > 0 { + t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable) + } + + c.audit.Entries[1].Local = true + if err := c.persistAudit(c.audit); err != nil { + t.Fatal(err) + } + + rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local audit") + } + localAuditTable = &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil { + t.Fatal(err) + } + if len(localAuditTable.Entries) != 1 { + t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable) + } + + oldAudit := c.audit + if err := c.loadAudits(); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(oldAudit, c.audit) { + t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit) + } + + if len(c.audit.Entries) != 2 { + t.Fatalf("expected two audit entries, got %#v", localAuditTable) + } +} + func TestCore_DisableAudit(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { @@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) { // Verify matching mount tables if !reflect.DeepEqual(c.audit, c2.audit) { - t.Fatalf("mismatch: %v %v", c.audit, c2.audit) + t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit) } } diff --git a/vault/cluster.go b/vault/cluster.go index 732080759..1686c09ce 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -43,7 +43,7 @@ var ( // This can be one of a few key types so the different params may or may not be filled type clusterKeyParams struct { - Type string `json:"type"` + Type string `json:"type" structs:"type" mapstructure:"type"` X *big.Int `json:"x" structs:"x" mapstructure:"x"` Y *big.Int `json:"y" structs:"y" mapstructure:"y"` D *big.Int `json:"d" structs:"d" mapstructure:"d"` @@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() { c.logger.Info("core/stopClusterListener: success") } -// ClusterTLSConfig generates a TLS configuration based on the local cluster -// key and cert. +// ClusterTLSConfig generates a TLS configuration based on the local/replicated +// cluster key and cert. func (c *Core) ClusterTLSConfig() (*tls.Config, error) { cluster, err := c.Cluster() if err != nil { return nil, err } if cluster == nil { - return nil, fmt.Errorf("cluster information is nil") + return nil, fmt.Errorf("local cluster information is nil") } // Prevent data races with the TLS parameters c.clusterParamsLock.Lock() defer c.clusterParamsLock.Unlock() - if c.localClusterCert == nil || len(c.localClusterCert) == 0 { - return nil, fmt.Errorf("cluster certificate is nil") + forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0 + + var parsedCert *x509.Certificate + if forwarding { + parsedCert, err = x509.ParseCertificate(c.localClusterCert) + if err != nil { + return nil, fmt.Errorf("error parsing local cluster certificate: %v", err) + } + + // This is idempotent, so be sure it's been added + c.clusterCertPool.AddCert(parsedCert) } - parsedCert, err := x509.ParseCertificate(c.localClusterCert) - if err != nil { - return nil, fmt.Errorf("error parsing local cluster certificate: %v", err) - } + nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + c.clusterParamsLock.RLock() + defer c.clusterParamsLock.RUnlock() - // This is idempotent, so be sure it's been added - c.clusterCertPool.AddCert(parsedCert) - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{ - tls.Certificate{ + if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName { + return &tls.Certificate{ Certificate: [][]byte{c.localClusterCert}, PrivateKey: c.localClusterPrivateKey, - }, - }, - RootCAs: c.clusterCertPool, - ServerName: parsedCert.Subject.CommonName, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: c.clusterCertPool, - MinVersion: tls.VersionTLS12, + }, nil + } + + return nil, nil + } + + var clientCertificates []tls.Certificate + if forwarding { + clientCertificates = append(clientCertificates, tls.Certificate{ + Certificate: [][]byte{c.localClusterCert}, + PrivateKey: c.localClusterPrivateKey, + }) + } + + tlsConfig := &tls.Config{ + // We need this here for the client side + Certificates: clientCertificates, + RootCAs: c.clusterCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: c.clusterCertPool, + GetCertificate: nameLookup, + MinVersion: tls.VersionTLS12, + } + if forwarding { + tlsConfig.ServerName = parsedCert.Subject.CommonName } return tlsConfig, nil diff --git a/vault/core.go b/vault/core.go index f3b9bf696..02a1c010c 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1,9 +1,9 @@ package vault import ( - "bytes" "crypto" "crypto/ecdsa" + "crypto/subtle" "crypto/x509" "errors" "fmt" @@ -57,6 +57,11 @@ const ( // leaderPrefixCleanDelay is how long to wait between deletions // of orphaned leader keys, to prevent slamming the backend. leaderPrefixCleanDelay = 200 * time.Millisecond + + // coreKeyringCanaryPath is used as a canary to indicate to replicated + // clusters that they need to perform a rekey operation synchronously; this + // isn't keyring-canary to avoid ignoring it when ignoring core/keyring + coreKeyringCanaryPath = "core/canary-keyring" ) var ( @@ -80,6 +85,12 @@ var ( // step down of the active node, to prevent instantly regrabbing the lock. // It's var not const so that tests can manipulate it. manualStepDownSleepPeriod = 10 * time.Second + + // Functions only in the Enterprise version + enterprisePostUnseal = enterprisePostUnsealImpl + enterprisePreSeal = enterprisePreSealImpl + startReplication = startReplicationImpl + stopReplication = stopReplicationImpl ) // ReloadFunc are functions that are called when a reload is requested. @@ -126,6 +137,11 @@ type unlockInformation struct { // interface for API handlers and is responsible for managing the logical and physical // backends, router, security barrier, and audit trails. type Core struct { + // N.B.: This is used to populate a dev token down replication, as + // otherwise, after replication is started, a dev would have to go through + // the generate-root process simply to talk to the new follower cluster. + devToken string + // HABackend may be available depending on the physical backend ha physical.HABackend @@ -261,7 +277,7 @@ type Core struct { // // Name clusterName string - // Used to modify cluster TLS params + // Used to modify cluster parameters clusterParamsLock sync.RWMutex // The private key stored in the barrier used for establishing // mutually-authenticated connections between Vault cluster members @@ -308,6 +324,8 @@ type Core struct { // CoreConfig is used to parameterize a core type CoreConfig struct { + DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"` + LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"` CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"` @@ -383,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) { conf.Logger = logformat.NewVaultLogger(log.LevelTrace) } + // Setup the core + c := &Core{ + redirectAddr: conf.RedirectAddr, + clusterAddr: conf.ClusterAddr, + physical: conf.Physical, + seal: conf.Seal, + router: NewRouter(), + sealed: true, + standby: true, + logger: conf.Logger, + defaultLeaseTTL: conf.DefaultLeaseTTL, + maxLeaseTTL: conf.MaxLeaseTTL, + cachingDisabled: conf.DisableCache, + clusterName: conf.ClusterName, + clusterCertPool: x509.NewCertPool(), + clusterListenerShutdownCh: make(chan struct{}), + clusterListenerShutdownSuccessCh: make(chan struct{}), + } + + // Wrap the physical backend in a cache layer if enabled and not already wrapped + if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache { + c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger) + } + if !conf.DisableMlock { // Ensure our memory usage is locked into physical RAM if err := mlock.LockMemory(); err != nil { @@ -400,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) { } // Construct a new AES-GCM barrier - barrier, err := NewAESGCMBarrier(conf.Physical) + var err error + c.barrier, err = NewAESGCMBarrier(c.physical) if err != nil { return nil, fmt.Errorf("barrier setup failed: %v", err) } - // Setup the core - c := &Core{ - redirectAddr: conf.RedirectAddr, - clusterAddr: conf.ClusterAddr, - physical: conf.Physical, - seal: conf.Seal, - barrier: barrier, - router: NewRouter(), - sealed: true, - standby: true, - logger: conf.Logger, - defaultLeaseTTL: conf.DefaultLeaseTTL, - maxLeaseTTL: conf.MaxLeaseTTL, - cachingDisabled: conf.DisableCache, - clusterName: conf.ClusterName, - clusterCertPool: x509.NewCertPool(), - clusterListenerShutdownCh: make(chan struct{}), - clusterListenerShutdownSuccessCh: make(chan struct{}), - } - - // Wrap the backend in a cache unless disabled - if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache { - c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger) - } - if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() { c.ha = conf.HAPhysical } @@ -796,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) { return true, nil } + masterKey, err := c.unsealPart(config, key) + if err != nil { + return false, err + } + if masterKey != nil { + return c.unsealInternal(masterKey) + } + + return false, nil +} + +func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) { // Check if we already have this piece if c.unlockInfo != nil { for _, existing := range c.unlockInfo.Parts { - if bytes.Equal(existing, key) { - return false, nil + if subtle.ConstantTimeCompare(existing, key) == 1 { + return nil, nil } } } else { uuid, err := uuid.GenerateUUID() if err != nil { - return false, err + return nil, err } c.unlockInfo = &unlockInformation{ Nonce: uuid, @@ -821,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) { if c.logger.IsDebug() { c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce) } - return false, nil + return nil, nil } + // Best-effort memzero of unlock parts once we're done with them + defer func() { + for i, _ := range c.unlockInfo.Parts { + memzero(c.unlockInfo.Parts[i]) + } + c.unlockInfo = nil + }() + // Recover the master key var masterKey []byte + var err error if config.SecretThreshold == 1 { - masterKey = c.unlockInfo.Parts[0] - c.unlockInfo = nil + masterKey = make([]byte, len(c.unlockInfo.Parts[0])) + copy(masterKey, c.unlockInfo.Parts[0]) } else { masterKey, err = shamir.Combine(c.unlockInfo.Parts) - c.unlockInfo = nil if err != nil { - return false, fmt.Errorf("failed to compute master key: %v", err) + return nil, fmt.Errorf("failed to compute master key: %v", err) } } - defer memzero(masterKey) - return c.unsealInternal(masterKey) + return masterKey, nil } +// This must be called with the state write lock held func (c *Core) unsealInternal(masterKey []byte) (bool, error) { + defer memzero(masterKey) + // Attempt to unlock if err := c.barrier.Unseal(masterKey); err != nil { return false, err @@ -860,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) { c.logger.Warn("core: vault is sealed") return false, err } + if err := c.postUnseal(); err != nil { c.logger.Error("core: post-unseal setup failed", "error", err) c.barrier.Seal() c.logger.Warn("core: vault is sealed") return false, err } + c.standby = false } else { // Go to standby mode, wait until we are active to unseal @@ -1161,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) { if purgable, ok := c.physical.(physical.Purgable); ok { purgable.Purge() } + // HA mode requires us to handle keyring rotation and rekeying if c.ha != nil { // We want to reload these from disk so that in case of a rekey we're @@ -1183,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) { return err } } + if err := enterprisePostUnseal(c); err != nil { + return err + } if err := c.ensureWrappingKey(); err != nil { return err } @@ -1244,6 +1290,7 @@ func (c *Core) preSeal() error { c.metricsCh = nil } var result error + if c.ha != nil { c.stopClusterListener() } @@ -1266,6 +1313,10 @@ func (c *Core) preSeal() error { if err := c.unloadMounts(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err)) } + if err := enterprisePreSeal(c); err != nil { + result = multierror.Append(result, err) + } + // Purge the backend if supported if purgable, ok := c.physical.(physical.Purgable); ok { purgable.Purge() @@ -1274,6 +1325,22 @@ func (c *Core) preSeal() error { return result } +func enterprisePostUnsealImpl(c *Core) error { + return nil +} + +func enterprisePreSealImpl(c *Core) error { + return nil +} + +func startReplicationImpl(c *Core) error { + return nil +} + +func stopReplicationImpl(c *Core) error { + return nil +} + // runStandby is a long running routine that is used when an HA backend // is enabled. It waits until we are leader and switches this Vault to // active. diff --git a/vault/logical_system.go b/vault/logical_system.go index e0f116321..b3756bcc9 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -22,6 +22,23 @@ var ( protectedPaths = []string{ "core", } + + replicationPaths = []*framework.Path{ + &framework.Path{ + Pattern: "replication/status", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + var state consts.ReplicationState + resp := &logical.Response{ + Data: map[string]interface{}{ + "mode": state.String(), + }, + } + return resp, nil + }, + }, + }, + } ) func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backend, error) { @@ -675,7 +692,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen }, } - b.Backend.Paths = append(b.Backend.Paths, b.replicationPaths()...) + b.Backend.Paths = append(b.Backend.Paths, replicationPaths...) b.Backend.Invalidate = b.invalidate diff --git a/vault/mount.go b/vault/mount.go index 072583ae9..a7e29a2a9 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -19,6 +19,10 @@ const ( // can only be viewed or modified after an unseal. coreMountConfigPath = "core/mounts" + // coreLocalMountConfigPath is used to store mount configuration for local + // (non-replicated) mounts + coreLocalMountConfigPath = "core/local-mounts" + // backendBarrierPrefix is the prefix to the UUID used in the // barrier view for the backends. backendBarrierPrefix = "logical/" @@ -124,6 +128,7 @@ type MountEntry struct { UUID string `json:"uuid"` // Barrier view UUID Config MountConfig `json:"config"` // Configuration related to this mount (but not backend-derived) Options map[string]string `json:"options"` // Backend options + Local bool `json:"local"` // Local mounts are not replicated or affected by replication Tainted bool `json:"tainted,omitempty"` // Set as a Write-Ahead flag for unmount/remount } @@ -147,28 +152,29 @@ func (e *MountEntry) Clone() *MountEntry { UUID: e.UUID, Config: e.Config, Options: optClone, + Local: e.Local, Tainted: e.Tainted, } } // Mount is used to mount a new backend to the mount table. -func (c *Core) mount(me *MountEntry) error { +func (c *Core) mount(entry *MountEntry) error { // Ensure we end the path in a slash - if !strings.HasSuffix(me.Path, "/") { - me.Path += "/" + if !strings.HasSuffix(entry.Path, "/") { + entry.Path += "/" } // Prevent protected paths from being mounted for _, p := range protectedMounts { - if strings.HasPrefix(me.Path, p) { - return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path)) + if strings.HasPrefix(entry.Path, p) { + return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", entry.Path)) } } // Do not allow more than one instance of a singleton mount for _, p := range singletonMounts { - if me.Type == p { - return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", me.Type)) + if entry.Type == p { + return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", entry.Type)) } } @@ -176,37 +182,47 @@ func (c *Core) mount(me *MountEntry) error { defer c.mountsLock.Unlock() // Verify there is no conflicting mount - if match := c.router.MatchingMount(me.Path); match != "" { + if match := c.router.MatchingMount(entry.Path); match != "" { return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match)) } // Generate a new UUID and view - meUUID, err := uuid.GenerateUUID() + if entry.UUID == "" { + entryUUID, err := uuid.GenerateUUID() + if err != nil { + return err + } + entry.UUID = entryUUID + } + viewPath := backendBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) + sysView := c.mountEntrySysView(entry) + + backend, err := c.newLogicalBackend(entry.Type, sysView, view, nil) if err != nil { return err } - me.UUID = meUUID - view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/") - backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil) - if err != nil { + // Call initialize; this takes care of init tasks that must be run after + // the ignore paths are collected + if err := backend.Initialize(); err != nil { return err } newTable := c.mounts.shallowClone() - newTable.Entries = append(newTable.Entries, me) + newTable.Entries = append(newTable.Entries, entry) if err := c.persistMounts(newTable); err != nil { c.logger.Error("core: failed to update mount table", "error", err) return logical.CodedError(500, "failed to update mount table") } c.mounts = newTable - if err := c.router.Mount(backend, me.Path, me, view); err != nil { + if err := c.router.Mount(backend, entry.Path, entry, view); err != nil { return err } if c.logger.IsInfo() { - c.logger.Info("core: successful mount", "path", me.Path, "type", me.Type) + c.logger.Info("core: successful mount", "path", entry.Path, "type", entry.Type) } return nil } @@ -291,6 +307,12 @@ func (c *Core) removeMountEntry(path string) error { newTable := c.mounts.shallowClone() newTable.remove(path) + // When unmounting all entries the JSON code will load back up from storage + // as a nil slice, which kills tests...just set it nil explicitly + if len(newTable.Entries) == 0 { + newTable.Entries = nil + } + // Update the mount table if err := c.persistMounts(newTable); err != nil { c.logger.Error("core: failed to update mount table", "error", err) @@ -405,12 +427,18 @@ func (c *Core) remount(src, dst string) error { // loadMounts is invoked as part of postUnseal to load the mount table func (c *Core) loadMounts() error { mountTable := &MountTable{} + localMountTable := &MountTable{} // Load the existing mount table raw, err := c.barrier.Get(coreMountConfigPath) if err != nil { c.logger.Error("core: failed to read mount table", "error", err) return errLoadMountsFailed } + rawLocal, err := c.barrier.Get(coreLocalMountConfigPath) + if err != nil { + c.logger.Error("core: failed to read local mount table", "error", err) + return errLoadMountsFailed + } c.mountsLock.Lock() defer c.mountsLock.Unlock() @@ -425,6 +453,13 @@ func (c *Core) loadMounts() error { } c.mounts = mountTable } + if rawLocal != nil { + if err := jsonutil.DecodeJSON(rawLocal.Value, localMountTable); err != nil { + c.logger.Error("core: failed to decompress and/or decode the local mount table", "error", err) + return err + } + c.mounts.Entries = append(c.mounts.Entries, localMountTable.Entries...) + } // Ensure that required entries are loaded, or new ones // added may never get loaded at all. Note that this @@ -492,8 +527,24 @@ func (c *Core) persistMounts(table *MountTable) error { } } + nonLocalMounts := &MountTable{ + Type: mountTableType, + } + + localMounts := &MountTable{ + Type: mountTableType, + } + + for _, entry := range table.Entries { + if entry.Local { + localMounts.Entries = append(localMounts.Entries, entry) + } else { + nonLocalMounts.Entries = append(nonLocalMounts.Entries, entry) + } + } + // Encode the mount table into JSON and compress it (lzw). - compressedBytes, err := jsonutil.EncodeJSONAndCompress(table, nil) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil) if err != nil { c.logger.Error("core: failed to encode and/or compress the mount table", "error", err) return err @@ -510,6 +561,24 @@ func (c *Core) persistMounts(table *MountTable) error { c.logger.Error("core: failed to persist mount table", "error", err) return err } + + // Repeat with local mounts + compressedBytes, err = jsonutil.EncodeJSONAndCompress(localMounts, nil) + if err != nil { + c.logger.Error("core: failed to encode and/or compress the local mount table", "error", err) + return err + } + + entry = &Entry{ + Key: coreLocalMountConfigPath, + Value: compressedBytes, + } + + if err := c.barrier.Put(entry); err != nil { + c.logger.Error("core: failed to persist local mount table", "error", err) + return err + } + return nil } @@ -532,15 +601,19 @@ func (c *Core) setupMounts() error { // Create a barrier view using the UUID view = NewBarrierView(c.barrier, barrierPath) - + sysView := c.mountEntrySysView(entry) // Initialize the backend // Create the new backend - backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil) + backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil) if err != nil { c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err) return errLoadMountsFailed } + if err := backend.Initialize(); err != nil { + return err + } + switch entry.Type { case "system": c.systemBarrierView = view @@ -616,10 +689,10 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi // mountEntrySysView creates a logical.SystemView from global and // mount-specific entries; because this should be called when setting // up a mountEntry, it doesn't check to ensure that me is not nil -func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView { +func (c *Core) mountEntrySysView(entry *MountEntry) logical.SystemView { return dynamicSystemView{ core: c, - mountEntry: me, + mountEntry: entry, } } diff --git a/vault/mount_test.go b/vault/mount_test.go index dd6ef5944..a00d37945 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/compressutil" + "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/logical" ) @@ -82,6 +83,96 @@ func TestCore_Mount(t *testing.T) { } } +// Test that the local table actually gets populated as expected with local +// entries, and that upon reading the entries from both are recombined +// correctly +func TestCore_Mount_Local(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + c.mounts = &MountTable{ + Type: mountTableType, + Entries: []*MountEntry{ + &MountEntry{ + Table: mountTableType, + Path: "noop/", + Type: "generic", + UUID: "abcd", + }, + &MountEntry{ + Table: mountTableType, + Path: "noop2/", + Type: "generic", + UUID: "bcde", + }, + }, + } + + // Both should set up successfully + err := c.setupMounts() + if err != nil { + t.Fatal(err) + } + if len(c.mounts.Entries) != 2 { + t.Fatalf("expected two entries, got %d", len(c.mounts.Entries)) + } + + rawLocal, err := c.barrier.Get(coreLocalMountConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local mounts") + } + localMountsTable := &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil { + t.Fatal(err) + } + if len(localMountsTable.Entries) > 0 { + t.Fatalf("expected no entries in local mount table, got %#v", localMountsTable) + } + + c.mounts.Entries[1].Local = true + if err := c.persistMounts(c.mounts); err != nil { + t.Fatal(err) + } + + rawLocal, err = c.barrier.Get(coreLocalMountConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local mount") + } + localMountsTable = &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil { + t.Fatal(err) + } + if len(localMountsTable.Entries) != 1 { + t.Fatalf("expected one entry in local mount table, got %#v", localMountsTable) + } + + oldMounts := c.mounts + if err := c.loadMounts(); err != nil { + t.Fatal(err) + } + compEntries := c.mounts.Entries[:0] + // Filter out required mounts + for _, v := range c.mounts.Entries { + if v.Type == "generic" { + compEntries = append(compEntries, v) + } + } + c.mounts.Entries = compEntries + + if !reflect.DeepEqual(oldMounts, c.mounts) { + t.Fatalf("expected\n%#v\ngot\n%#v\n", oldMounts, c.mounts) + } + + if len(c.mounts.Entries) != 2 { + t.Fatalf("expected two mount entries, got %#v", localMountsTable) + } +} + func TestCore_Unmount(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) existed, err := c.unmount("secret") diff --git a/vault/policy_store.go b/vault/policy_store.go index 873a8dde8..8200f22cd 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -2,11 +2,13 @@ package vault import ( "fmt" + "strings" "time" "github.com/armon/go-metrics" "github.com/hashicorp/errwrap" "github.com/hashicorp/golang-lru" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" ) @@ -137,7 +139,13 @@ func (c *Core) setupPolicyStore() error { view := c.systemBarrierView.SubView(policySubPath) // Create the policy store - c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c}) + sysView := &dynamicSystemView{core: c} + c.policyStore = NewPolicyStore(view, sysView) + + if sysView.ReplicationState() == consts.ReplicationSecondary { + // Policies will sync from the primary + return nil + } // Ensure that the default policy exists, and if not, create it policy, err := c.policyStore.GetPolicy("default") @@ -173,6 +181,16 @@ func (c *Core) teardownPolicyStore() error { return nil } +func (ps *PolicyStore) invalidate(name string) { + if ps.lru == nil { + // Nothing to do if the cache is not used + return + } + + // This may come with a prefixed "/" due to joining the file path + ps.lru.Remove(strings.TrimPrefix(name, "/")) +} + // SetPolicy is used to create or update the given policy func (ps *PolicyStore) SetPolicy(p *Policy) error { defer metrics.MeasureSince([]string{"policy", "set_policy"}, time.Now())