More porting from rep (#2389)

* More porting from rep
* Address feedback

parent c81582fea0
commit 494b4c844b
@@ -2,7 +2,6 @@ package vault

 import (
 	"crypto/sha256"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"strings"
@@ -26,6 +25,10 @@ const (
 	// can only be viewed or modified after an unseal.
 	coreAuditConfigPath = "core/audit"

+	// coreLocalAuditConfigPath is used to store audit information for local
+	// (non-replicated) mounts
+	coreLocalAuditConfigPath = "core/local-audit"
+
 	// auditBarrierPrefix is the prefix to the UUID used in the
 	// barrier view for the audit backends.
 	auditBarrierPrefix = "audit/"
@@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
 	}

 	// Generate a new UUID and view
-	entryUUID, err := uuid.GenerateUUID()
-	if err != nil {
-		return err
+	if entry.UUID == "" {
+		entryUUID, err := uuid.GenerateUUID()
+		if err != nil {
+			return err
+		}
+		entry.UUID = entryUUID
 	}
-	entry.UUID = entryUUID
-	view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
+	viewPath := auditBarrierPrefix + entry.UUID + "/"
+	view := NewBarrierView(c.barrier, viewPath)

 	// Lookup the new backend
 	backend, err := c.newAuditBackend(entry, view, entry.Options)
@@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {

 	c.removeAuditReloadFunc(entry)

+	// When unmounting all entries the JSON code will load back up from storage
+	// as a nil slice, which kills tests...just set it nil explicitly
+	if len(newTable.Entries) == 0 {
+		newTable.Entries = nil
+	}
+
 	// Update the audit table
 	if err := c.persistAudit(newTable); err != nil {
 		return true, errors.New("failed to update audit table")
@@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) {
 	if c.logger.IsInfo() {
 		c.logger.Info("core: disabled audit backend", "path", path)
 	}

 	return true, nil
 }

 // loadAudits is invoked as part of postUnseal to load the audit table
 func (c *Core) loadAudits() error {
 	auditTable := &MountTable{}
+	localAuditTable := &MountTable{}
+
 	// Load the existing audit table
 	raw, err := c.barrier.Get(coreAuditConfigPath)
@@ -144,6 +158,11 @@ func (c *Core) loadAudits() error {
 		c.logger.Error("core: failed to read audit table", "error", err)
 		return errLoadAuditFailed
 	}
+	rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
+	if err != nil {
+		c.logger.Error("core: failed to read local audit table", "error", err)
+		return errLoadAuditFailed
+	}

 	c.auditLock.Lock()
 	defer c.auditLock.Unlock()
@@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
 		}
 		c.audit = auditTable
 	}
+	if rawLocal != nil {
+		if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
+			c.logger.Error("core: failed to decode local audit table", "error", err)
+			return errLoadAuditFailed
+		}
+		c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
+	}

 	// Done if we have restored the audit table
 	if c.audit != nil {
@@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error {
 		}
 	}

+	nonLocalAudit := &MountTable{
+		Type: auditTableType,
+	}
+
+	localAudit := &MountTable{
+		Type: auditTableType,
+	}
+
+	for _, entry := range table.Entries {
+		if entry.Local {
+			localAudit.Entries = append(localAudit.Entries, entry)
+		} else {
+			nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
+		}
+	}
+
 	// Marshal the table
-	raw, err := json.Marshal(table)
+	compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
 	if err != nil {
-		c.logger.Error("core: failed to encode audit table", "error", err)
+		c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
 		return err
 	}

 	// Create an entry
 	entry := &Entry{
 		Key:   coreAuditConfigPath,
-		Value: raw,
+		Value: compressedBytes,
 	}

 	// Write to the physical backend
@@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error {
 		c.logger.Error("core: failed to persist audit table", "error", err)
 		return err
 	}
+
+	// Repeat with local audit
+	compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
+	if err != nil {
+		c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
+		return err
+	}
+
+	entry = &Entry{
+		Key:   coreLocalAuditConfigPath,
+		Value: compressedBytes,
+	}
+
+	if err := c.barrier.Put(entry); err != nil {
+		c.logger.Error("core: failed to persist local audit table", "error", err)
+		return err
+	}

 	return nil
 }
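
Note on the two persistAudit hunks above: one in-memory table is partitioned by each entry's Local flag and written under two storage keys, and only core/audit participates in replication. Below is a minimal, self-contained sketch of that partition-and-persist pattern; the Table and Entry types and the in-memory store are simplified stand-ins invented for this note, not Vault's real types.

package main

import (
	"encoding/json"
	"fmt"
)

type Entry struct {
	Path  string
	Local bool
}

type Table struct {
	Entries []*Entry
}

// splitTable partitions entries by their Local flag, mirroring the
// nonLocalAudit/localAudit split in persistAudit above.
func splitTable(t *Table) (nonLocal, local *Table) {
	nonLocal, local = &Table{}, &Table{}
	for _, e := range t.Entries {
		if e.Local {
			local.Entries = append(local.Entries, e)
		} else {
			nonLocal.Entries = append(nonLocal.Entries, e)
		}
	}
	return nonLocal, local
}

func main() {
	t := &Table{Entries: []*Entry{
		{Path: "file/", Local: false},
		{Path: "local-file/", Local: true},
	}}
	nonLocal, local := splitTable(t)
	// Each half would be written under its own storage key; only the
	// first key would be replicated to other clusters.
	for name, half := range map[string]*Table{"core/audit": nonLocal, "core/local-audit": local} {
		raw, _ := json.Marshal(half)
		fmt.Printf("%s => %s\n", name, raw)
	}
}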
@@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {

 	for _, entry := range c.audit.Entries {
 		// Create a barrier view using the UUID
-		view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
+		viewPath := auditBarrierPrefix + entry.UUID + "/"
+		view := NewBarrierView(c.barrier, viewPath)

 		// Initialize the backend
 		audit, err := c.newAuditBackend(entry, view, entry.Options)

@@ -11,6 +11,7 @@ import (
 	"github.com/hashicorp/errwrap"
 	"github.com/hashicorp/go-uuid"
 	"github.com/hashicorp/vault/audit"
+	"github.com/hashicorp/vault/helper/jsonutil"
 	"github.com/hashicorp/vault/helper/logformat"
 	"github.com/hashicorp/vault/logical"
 	log "github.com/mgutz/logxi/v1"
@@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
 	}
 }

+// Test that the local table actually gets populated as expected with local
+// entries, and that upon reading the entries from both are recombined
+// correctly
+func TestCore_EnableAudit_Local(t *testing.T) {
+	c, _, _ := TestCoreUnsealed(t)
+	c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
+		return &NoopAudit{
+			Config: config,
+		}, nil
+	}
+
+	c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
+		return nil, fmt.Errorf("failing enabling")
+	}
+
+	c.audit = &MountTable{
+		Type: auditTableType,
+		Entries: []*MountEntry{
+			&MountEntry{
+				Table: auditTableType,
+				Path:  "noop/",
+				Type:  "noop",
+				UUID:  "abcd",
+			},
+			&MountEntry{
+				Table: auditTableType,
+				Path:  "noop2/",
+				Type:  "noop",
+				UUID:  "bcde",
+			},
+		},
+	}
+
+	// Both should set up successfully
+	err := c.setupAudits()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if rawLocal == nil {
+		t.Fatal("expected non-nil local audit")
+	}
+	localAuditTable := &MountTable{}
+	if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
+		t.Fatal(err)
+	}
+	if len(localAuditTable.Entries) > 0 {
+		t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
+	}
+
+	c.audit.Entries[1].Local = true
+	if err := c.persistAudit(c.audit); err != nil {
+		t.Fatal(err)
+	}
+
+	rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if rawLocal == nil {
+		t.Fatal("expected non-nil local audit")
+	}
+	localAuditTable = &MountTable{}
+	if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
+		t.Fatal(err)
+	}
+	if len(localAuditTable.Entries) != 1 {
+		t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
+	}
+
+	oldAudit := c.audit
+	if err := c.loadAudits(); err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(oldAudit, c.audit) {
+		t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
+	}
+
+	if len(c.audit.Entries) != 2 {
+		t.Fatalf("expected two audit entries, got %#v", localAuditTable)
+	}
+}
+
 func TestCore_DisableAudit(t *testing.T) {
 	c, keys, _ := TestCoreUnsealed(t)
 	c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
@@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {

 	// Verify matching mount tables
 	if !reflect.DeepEqual(c.audit, c2.audit) {
-		t.Fatalf("mismatch: %v %v", c.audit, c2.audit)
+		t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
 	}
 }

@@ -43,7 +43,7 @@ var (

 // This can be one of a few key types so the different params may or may not be filled
 type clusterKeyParams struct {
-	Type string   `json:"type"`
+	Type string   `json:"type" structs:"type" mapstructure:"type"`
 	X    *big.Int `json:"x" structs:"x" mapstructure:"x"`
 	Y    *big.Int `json:"y" structs:"y" mapstructure:"y"`
 	D    *big.Int `json:"d" structs:"d" mapstructure:"d"`
@@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
 	c.logger.Info("core/stopClusterListener: success")
 }

-// ClusterTLSConfig generates a TLS configuration based on the local cluster
-// key and cert.
+// ClusterTLSConfig generates a TLS configuration based on the local/replicated
+// cluster key and cert.
 func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
 	cluster, err := c.Cluster()
 	if err != nil {
 		return nil, err
 	}
 	if cluster == nil {
-		return nil, fmt.Errorf("cluster information is nil")
+		return nil, fmt.Errorf("local cluster information is nil")
 	}

 	// Prevent data races with the TLS parameters
 	c.clusterParamsLock.Lock()
 	defer c.clusterParamsLock.Unlock()

-	if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
-		return nil, fmt.Errorf("cluster certificate is nil")
+	forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0
+
+	var parsedCert *x509.Certificate
+	if forwarding {
+		parsedCert, err = x509.ParseCertificate(c.localClusterCert)
+		if err != nil {
+			return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
+		}
+
+		// This is idempotent, so be sure it's been added
+		c.clusterCertPool.AddCert(parsedCert)
 	}

-	parsedCert, err := x509.ParseCertificate(c.localClusterCert)
-	if err != nil {
-		return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
-	}
-
-	// This is idempotent, so be sure it's been added
-	c.clusterCertPool.AddCert(parsedCert)
-
-	tlsConfig := &tls.Config{
-		Certificates: []tls.Certificate{
-			tls.Certificate{
-				Certificate: [][]byte{c.localClusterCert},
-				PrivateKey:  c.localClusterPrivateKey,
-			},
-		},
-		RootCAs:    c.clusterCertPool,
-		ServerName: parsedCert.Subject.CommonName,
-		ClientAuth: tls.RequireAndVerifyClientCert,
-		ClientCAs:  c.clusterCertPool,
-		MinVersion: tls.VersionTLS12,
+	nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		c.clusterParamsLock.RLock()
+		defer c.clusterParamsLock.RUnlock()
+
+		if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
+			return &tls.Certificate{
+				Certificate: [][]byte{c.localClusterCert},
+				PrivateKey:  c.localClusterPrivateKey,
+			}, nil
+		}
+
+		return nil, nil
+	}
+
+	var clientCertificates []tls.Certificate
+	if forwarding {
+		clientCertificates = append(clientCertificates, tls.Certificate{
+			Certificate: [][]byte{c.localClusterCert},
+			PrivateKey:  c.localClusterPrivateKey,
+		})
+	}
+
+	tlsConfig := &tls.Config{
+		// We need this here for the client side
+		Certificates:   clientCertificates,
+		RootCAs:        c.clusterCertPool,
+		ClientAuth:     tls.RequireAndVerifyClientCert,
+		ClientCAs:      c.clusterCertPool,
+		GetCertificate: nameLookup,
+		MinVersion:     tls.VersionTLS12,
+	}
+	if forwarding {
+		tlsConfig.ServerName = parsedCert.Subject.CommonName
 	}

 	return tlsConfig, nil
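
Note on the ClusterTLSConfig rewrite above: instead of failing at config-build time when no local cluster certificate exists, the certificate is now served lazily through tls.Config's GetCertificate callback. A standalone sketch of that callback pattern follows; buildTLSConfig and its parameters are invented for illustration, and the standard-library tls types are the only real API used.

package main

import (
	"crypto/tls"
	"fmt"
)

// buildTLSConfig mirrors the shape of the new ClusterTLSConfig: when no
// cert is loaded yet ("forwarding" is false), a config is still returned
// and GetCertificate simply reports no match instead of erroring up front.
func buildTLSConfig(cert []byte, key interface{}, serverName string) *tls.Config {
	forwarding := len(cert) > 0

	nameLookup := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		if forwarding && hello.ServerName == serverName {
			return &tls.Certificate{
				Certificate: [][]byte{cert},
				PrivateKey:  key,
			}, nil
		}
		// Returning nil, nil lets the handshake fall back to
		// tls.Config.Certificates (empty here), so the failure happens
		// per-connection rather than when the config is built.
		return nil, nil
	}

	cfg := &tls.Config{
		GetCertificate: nameLookup,
		ClientAuth:     tls.RequireAndVerifyClientCert,
		MinVersion:     tls.VersionTLS12,
	}
	if forwarding {
		cfg.ServerName = serverName
	}
	return cfg
}

func main() {
	cfg := buildTLSConfig(nil, nil, "cluster-node")
	fmt.Println("config built without a cert; GetCertificate set:", cfg.GetCertificate != nil)
}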
vault/core.go (143 changes)
@@ -1,9 +1,9 @@
 package vault

 import (
-	"bytes"
 	"crypto"
 	"crypto/ecdsa"
+	"crypto/subtle"
 	"crypto/x509"
 	"errors"
 	"fmt"
@@ -57,6 +57,11 @@ const (
 	// leaderPrefixCleanDelay is how long to wait between deletions
 	// of orphaned leader keys, to prevent slamming the backend.
 	leaderPrefixCleanDelay = 200 * time.Millisecond
+
+	// coreKeyringCanaryPath is used as a canary to indicate to replicated
+	// clusters that they need to perform a rekey operation synchronously; this
+	// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
+	coreKeyringCanaryPath = "core/canary-keyring"
 )

 var (
@@ -80,6 +85,12 @@ var (
 	// step down of the active node, to prevent instantly regrabbing the lock.
 	// It's var not const so that tests can manipulate it.
 	manualStepDownSleepPeriod = 10 * time.Second
+
+	// Functions only in the Enterprise version
+	enterprisePostUnseal = enterprisePostUnsealImpl
+	enterprisePreSeal    = enterprisePreSealImpl
+	startReplication     = startReplicationImpl
+	stopReplication      = stopReplicationImpl
 )

 // ReloadFunc are functions that are called when a reload is requested.
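
Note: declaring enterprisePostUnseal and friends as package-level function variables is a common Go seam for swapping implementations between builds; the open-source file assigns no-op impls (defined in a later hunk), and a different build can reassign them at init time. A minimal sketch of the pattern, with all names invented for this example:

package main

import "fmt"

// Hook points declared as vars so another file (or build) can replace
// them, the way enterprisePostUnseal/startReplication are wired above.
var (
	postUnseal = postUnsealImpl
	preSeal    = preSealImpl
)

func postUnsealImpl() error { return nil } // open-source no-op
func preSealImpl() error    { return nil } // open-source no-op

func main() {
	// An alternate build could do: postUnseal = someOtherImpl
	if err := postUnseal(); err != nil {
		fmt.Println("post-unseal failed:", err)
		return
	}
	fmt.Println("post-unseal hooks ran")
}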
@@ -126,6 +137,11 @@ type unlockInformation struct {
 // interface for API handlers and is responsible for managing the logical and physical
 // backends, router, security barrier, and audit trails.
 type Core struct {
+	// N.B.: This is used to populate a dev token down replication, as
+	// otherwise, after replication is started, a dev would have to go through
+	// the generate-root process simply to talk to the new follower cluster.
+	devToken string
+
 	// HABackend may be available depending on the physical backend
 	ha physical.HABackend

@@ -261,7 +277,7 @@ type Core struct {
 	//
 	// Name
 	clusterName string
-	// Used to modify cluster TLS params
+	// Used to modify cluster parameters
 	clusterParamsLock sync.RWMutex
 	// The private key stored in the barrier used for establishing
 	// mutually-authenticated connections between Vault cluster members
@@ -308,6 +324,8 @@ type Core struct {

 // CoreConfig is used to parameterize a core
 type CoreConfig struct {
+	DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"`
+
 	LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`

 	CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
@@ -383,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 		conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
 	}

+	// Setup the core
+	c := &Core{
+		redirectAddr:                     conf.RedirectAddr,
+		clusterAddr:                      conf.ClusterAddr,
+		physical:                         conf.Physical,
+		seal:                             conf.Seal,
+		router:                           NewRouter(),
+		sealed:                           true,
+		standby:                          true,
+		logger:                           conf.Logger,
+		defaultLeaseTTL:                  conf.DefaultLeaseTTL,
+		maxLeaseTTL:                      conf.MaxLeaseTTL,
+		cachingDisabled:                  conf.DisableCache,
+		clusterName:                      conf.ClusterName,
+		clusterCertPool:                  x509.NewCertPool(),
+		clusterListenerShutdownCh:        make(chan struct{}),
+		clusterListenerShutdownSuccessCh: make(chan struct{}),
+	}
+
+	// Wrap the physical backend in a cache layer if enabled and not already wrapped
+	if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
+		c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
+	}
+
 	if !conf.DisableMlock {
 		// Ensure our memory usage is locked into physical RAM
 		if err := mlock.LockMemory(); err != nil {
@@ -400,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 	}

 	// Construct a new AES-GCM barrier
-	barrier, err := NewAESGCMBarrier(conf.Physical)
+	var err error
+	c.barrier, err = NewAESGCMBarrier(c.physical)
 	if err != nil {
 		return nil, fmt.Errorf("barrier setup failed: %v", err)
 	}

-	// Setup the core
-	c := &Core{
-		redirectAddr:                     conf.RedirectAddr,
-		clusterAddr:                      conf.ClusterAddr,
-		physical:                         conf.Physical,
-		seal:                             conf.Seal,
-		barrier:                          barrier,
-		router:                           NewRouter(),
-		sealed:                           true,
-		standby:                          true,
-		logger:                           conf.Logger,
-		defaultLeaseTTL:                  conf.DefaultLeaseTTL,
-		maxLeaseTTL:                      conf.MaxLeaseTTL,
-		cachingDisabled:                  conf.DisableCache,
-		clusterName:                      conf.ClusterName,
-		clusterCertPool:                  x509.NewCertPool(),
-		clusterListenerShutdownCh:        make(chan struct{}),
-		clusterListenerShutdownSuccessCh: make(chan struct{}),
-	}
-
-	// Wrap the backend in a cache unless disabled
-	if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
-		c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
-	}
-
 	if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
 		c.ha = conf.HAPhysical
 	}
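
Note: the NewCore reordering above is not just cosmetic. Building the Core struct and the cache wrapper before the barrier means NewAESGCMBarrier(c.physical) now layers on top of the cache rather than on the raw conf.Physical. A toy illustration of why wrapper order matters; all types here are invented for the example:

package main

import "fmt"

type backend interface{ name() string }

type raw struct{}

func (raw) name() string { return "raw" }

type cache struct{ inner backend }

func (c cache) name() string { return "cache(" + c.inner.name() + ")" }

func main() {
	var phys backend = raw{}
	// Wrap first, as the reordered NewCore does...
	phys = cache{inner: phys}
	// ...so anything layered on top (the barrier) reads through the cache.
	fmt.Println("barrier reads through:", phys.name()) // cache(raw)
}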
@@ -796,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) {
 		return true, nil
 	}

+	masterKey, err := c.unsealPart(config, key)
+	if err != nil {
+		return false, err
+	}
+	if masterKey != nil {
+		return c.unsealInternal(masterKey)
+	}
+
+	return false, nil
+}
+
+func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) {
 	// Check if we already have this piece
 	if c.unlockInfo != nil {
 		for _, existing := range c.unlockInfo.Parts {
-			if bytes.Equal(existing, key) {
-				return false, nil
+			if subtle.ConstantTimeCompare(existing, key) == 1 {
+				return nil, nil
 			}
 		}
 	} else {
 		uuid, err := uuid.GenerateUUID()
 		if err != nil {
-			return false, err
+			return nil, err
 		}
 		c.unlockInfo = &unlockInformation{
 			Nonce: uuid,
@@ -821,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) {
 		if c.logger.IsDebug() {
 			c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
 		}
-		return false, nil
+		return nil, nil
 	}

+	// Best-effort memzero of unlock parts once we're done with them
+	defer func() {
+		for i, _ := range c.unlockInfo.Parts {
+			memzero(c.unlockInfo.Parts[i])
+		}
+		c.unlockInfo = nil
+	}()
+
 	// Recover the master key
 	var masterKey []byte
+	var err error
 	if config.SecretThreshold == 1 {
-		masterKey = c.unlockInfo.Parts[0]
-		c.unlockInfo = nil
+		masterKey = make([]byte, len(c.unlockInfo.Parts[0]))
+		copy(masterKey, c.unlockInfo.Parts[0])
 	} else {
 		masterKey, err = shamir.Combine(c.unlockInfo.Parts)
-		c.unlockInfo = nil
 		if err != nil {
-			return false, fmt.Errorf("failed to compute master key: %v", err)
+			return nil, fmt.Errorf("failed to compute master key: %v", err)
 		}
 	}
-	defer memzero(masterKey)

-	return c.unsealInternal(masterKey)
+	return masterKey, nil
+}
+
+// This must be called with the state write lock held
 func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
+	defer memzero(masterKey)
+
 	// Attempt to unlock
 	if err := c.barrier.Unseal(masterKey); err != nil {
 		return false, err
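
Two details in the unsealPart refactor above are worth calling out: key shares are compared with subtle.ConstantTimeCompare instead of bytes.Equal, avoiding a timing side channel, and the single-share case copies the part before the deferred memzero wipes c.unlockInfo.Parts, since returning the original slice would hand back zeroed memory. A hedged, self-contained sketch of both points (memzero and takePart are local helpers invented for this note):

package main

import (
	"crypto/subtle"
	"fmt"
)

func memzero(b []byte) {
	for i := range b {
		b[i] = 0
	}
}

func takePart(parts [][]byte, key []byte) []byte {
	// Constant-time duplicate check: runtime does not depend on how
	// many prefix bytes match, unlike bytes.Equal.
	for _, existing := range parts {
		if subtle.ConstantTimeCompare(existing, key) == 1 {
			return nil
		}
	}
	// Copy before wiping; the caller keeps masterKey while parts are zeroed.
	defer func() {
		for i := range parts {
			memzero(parts[i])
		}
	}()
	masterKey := make([]byte, len(parts[0]))
	copy(masterKey, parts[0])
	return masterKey
}

func main() {
	parts := [][]byte{{1, 2, 3}}
	got := takePart(parts, []byte{9, 9, 9})
	fmt.Println(got, parts[0]) // [1 2 3] [0 0 0]
}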
@@ -860,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
 			c.logger.Warn("core: vault is sealed")
 			return false, err
 		}

 		if err := c.postUnseal(); err != nil {
 			c.logger.Error("core: post-unseal setup failed", "error", err)
+			c.barrier.Seal()
+			c.logger.Warn("core: vault is sealed")
 			return false, err
 		}

 		c.standby = false
 	} else {
 		// Go to standby mode, wait until we are active to unseal
@@ -1161,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) {
 	if purgable, ok := c.physical.(physical.Purgable); ok {
 		purgable.Purge()
 	}
+
 	// HA mode requires us to handle keyring rotation and rekeying
 	if c.ha != nil {
 		// We want to reload these from disk so that in case of a rekey we're
@@ -1183,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) {
 			return err
 		}
 	}
+	if err := enterprisePostUnseal(c); err != nil {
+		return err
+	}
 	if err := c.ensureWrappingKey(); err != nil {
 		return err
 	}
@@ -1244,6 +1290,7 @@ func (c *Core) preSeal() error {
 		c.metricsCh = nil
 	}
 	var result error
+
 	if c.ha != nil {
 		c.stopClusterListener()
 	}
@@ -1266,6 +1313,10 @@ func (c *Core) preSeal() error {
 	if err := c.unloadMounts(); err != nil {
 		result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
 	}
+	if err := enterprisePreSeal(c); err != nil {
+		result = multierror.Append(result, err)
+	}
+
 	// Purge the backend if supported
 	if purgable, ok := c.physical.(physical.Purgable); ok {
 		purgable.Purge()
@@ -1274,6 +1325,22 @@ func (c *Core) preSeal() error {
 	return result
 }

+func enterprisePostUnsealImpl(c *Core) error {
+	return nil
+}
+
+func enterprisePreSealImpl(c *Core) error {
+	return nil
+}
+
+func startReplicationImpl(c *Core) error {
+	return nil
+}
+
+func stopReplicationImpl(c *Core) error {
+	return nil
+}
+
 // runStandby is a long running routine that is used when an HA backend
 // is enabled. It waits until we are leader and switches this Vault to
 // active.

@@ -22,6 +22,23 @@ var (
 	protectedPaths = []string{
 		"core",
 	}
+
+	replicationPaths = []*framework.Path{
+		&framework.Path{
+			Pattern: "replication/status",
+			Callbacks: map[logical.Operation]framework.OperationFunc{
+				logical.ReadOperation: func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
+					var state consts.ReplicationState
+					resp := &logical.Response{
+						Data: map[string]interface{}{
+							"mode": state.String(),
+						},
+					}
+					return resp, nil
+				},
+			},
+		},
+	}
 )

 func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backend, error) {
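
Note: replication/status above is a read-only endpoint whose handler relies on the zero value of consts.ReplicationState rendering as the disabled mode in the open-source build. A generic sketch of the same idea, a status endpoint reporting a zero-valued mode, using only net/http rather than Vault's framework types (replicationState and statusHandler are invented for this note):

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// replicationState's zero value stands for "not enabled", mirroring how the
// OSS handler reports the zero consts.ReplicationState.
type replicationState int

func (s replicationState) String() string {
	if s == 0 {
		return "disabled"
	}
	return "enabled"
}

func statusHandler(w http.ResponseWriter, r *http.Request) {
	var state replicationState // zero value => "disabled"
	json.NewEncoder(w).Encode(map[string]interface{}{"mode": state.String()})
}

func main() {
	srv := httptest.NewServer(http.HandlerFunc(statusHandler))
	defer srv.Close()
	resp, err := http.Get(srv.URL + "/replication/status")
	if err != nil {
		panic(err)
	}
	var body map[string]interface{}
	json.NewDecoder(resp.Body).Decode(&body)
	fmt.Println(body["mode"]) // disabled
}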
@@ -675,7 +692,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backend, error) {
 		},
 	}

-	b.Backend.Paths = append(b.Backend.Paths, b.replicationPaths()...)
+	b.Backend.Paths = append(b.Backend.Paths, replicationPaths...)

 	b.Backend.Invalidate = b.invalidate

vault/mount.go (115 changes)
@@ -19,6 +19,10 @@ const (
 	// can only be viewed or modified after an unseal.
 	coreMountConfigPath = "core/mounts"

+	// coreLocalMountConfigPath is used to store mount configuration for local
+	// (non-replicated) mounts
+	coreLocalMountConfigPath = "core/local-mounts"
+
 	// backendBarrierPrefix is the prefix to the UUID used in the
 	// barrier view for the backends.
 	backendBarrierPrefix = "logical/"
@@ -124,6 +128,7 @@ type MountEntry struct {
 	UUID    string            `json:"uuid"`              // Barrier view UUID
 	Config  MountConfig       `json:"config"`            // Configuration related to this mount (but not backend-derived)
 	Options map[string]string `json:"options"`           // Backend options
+	Local   bool              `json:"local"`             // Local mounts are not replicated or affected by replication
 	Tainted bool              `json:"tainted,omitempty"` // Set as a Write-Ahead flag for unmount/remount
 }

@@ -147,28 +152,29 @@ func (e *MountEntry) Clone() *MountEntry {
 		UUID:    e.UUID,
 		Config:  e.Config,
 		Options: optClone,
+		Local:   e.Local,
 		Tainted: e.Tainted,
 	}
 }

 // Mount is used to mount a new backend to the mount table.
-func (c *Core) mount(me *MountEntry) error {
+func (c *Core) mount(entry *MountEntry) error {
 	// Ensure we end the path in a slash
-	if !strings.HasSuffix(me.Path, "/") {
-		me.Path += "/"
+	if !strings.HasSuffix(entry.Path, "/") {
+		entry.Path += "/"
 	}

 	// Prevent protected paths from being mounted
 	for _, p := range protectedMounts {
-		if strings.HasPrefix(me.Path, p) {
-			return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
+		if strings.HasPrefix(entry.Path, p) {
+			return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", entry.Path))
 		}
 	}

 	// Do not allow more than one instance of a singleton mount
 	for _, p := range singletonMounts {
-		if me.Type == p {
-			return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", me.Type))
+		if entry.Type == p {
+			return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", entry.Type))
 		}
 	}

@@ -176,37 +182,47 @@ func (c *Core) mount(me *MountEntry) error {
 	defer c.mountsLock.Unlock()

 	// Verify there is no conflicting mount
-	if match := c.router.MatchingMount(me.Path); match != "" {
+	if match := c.router.MatchingMount(entry.Path); match != "" {
 		return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
 	}

 	// Generate a new UUID and view
-	meUUID, err := uuid.GenerateUUID()
+	if entry.UUID == "" {
+		entryUUID, err := uuid.GenerateUUID()
+		if err != nil {
+			return err
+		}
+		entry.UUID = entryUUID
+	}
+	viewPath := backendBarrierPrefix + entry.UUID + "/"
+	view := NewBarrierView(c.barrier, viewPath)
+	sysView := c.mountEntrySysView(entry)
+
+	backend, err := c.newLogicalBackend(entry.Type, sysView, view, nil)
 	if err != nil {
 		return err
 	}
-	me.UUID = meUUID
-	view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")

-	backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil)
-	if err != nil {
-		return err
-	}
+	// Call initialize; this takes care of init tasks that must be run after
+	// the ignore paths are collected
+	if err := backend.Initialize(); err != nil {
+		return err
+	}

 	newTable := c.mounts.shallowClone()
-	newTable.Entries = append(newTable.Entries, me)
+	newTable.Entries = append(newTable.Entries, entry)
 	if err := c.persistMounts(newTable); err != nil {
 		c.logger.Error("core: failed to update mount table", "error", err)
 		return logical.CodedError(500, "failed to update mount table")
 	}
 	c.mounts = newTable

-	if err := c.router.Mount(backend, me.Path, me, view); err != nil {
+	if err := c.router.Mount(backend, entry.Path, entry, view); err != nil {
 		return err
 	}

 	if c.logger.IsInfo() {
-		c.logger.Info("core: successful mount", "path", me.Path, "type", me.Type)
+		c.logger.Info("core: successful mount", "path", entry.Path, "type", entry.Type)
 	}
 	return nil
 }
@@ -291,6 +307,12 @@ func (c *Core) removeMountEntry(path string) error {
 	newTable := c.mounts.shallowClone()
 	newTable.remove(path)

+	// When unmounting all entries the JSON code will load back up from storage
+	// as a nil slice, which kills tests...just set it nil explicitly
+	if len(newTable.Entries) == 0 {
+		newTable.Entries = nil
+	}
+
 	// Update the mount table
 	if err := c.persistMounts(newTable); err != nil {
 		c.logger.Error("core: failed to update mount table", "error", err)
@@ -405,12 +427,18 @@ func (c *Core) remount(src, dst string) error {
 // loadMounts is invoked as part of postUnseal to load the mount table
 func (c *Core) loadMounts() error {
 	mountTable := &MountTable{}
+	localMountTable := &MountTable{}
 	// Load the existing mount table
 	raw, err := c.barrier.Get(coreMountConfigPath)
 	if err != nil {
 		c.logger.Error("core: failed to read mount table", "error", err)
 		return errLoadMountsFailed
 	}
+	rawLocal, err := c.barrier.Get(coreLocalMountConfigPath)
+	if err != nil {
+		c.logger.Error("core: failed to read local mount table", "error", err)
+		return errLoadMountsFailed
+	}

 	c.mountsLock.Lock()
 	defer c.mountsLock.Unlock()
@@ -425,6 +453,13 @@ func (c *Core) loadMounts() error {
 		}
 		c.mounts = mountTable
 	}
+	if rawLocal != nil {
+		if err := jsonutil.DecodeJSON(rawLocal.Value, localMountTable); err != nil {
+			c.logger.Error("core: failed to decompress and/or decode the local mount table", "error", err)
+			return err
+		}
+		c.mounts.Entries = append(c.mounts.Entries, localMountTable.Entries...)
+	}

 	// Ensure that required entries are loaded, or new ones
 	// added may never get loaded at all. Note that this
@@ -492,8 +527,24 @@ func (c *Core) persistMounts(table *MountTable) error {
 		}
 	}

+	nonLocalMounts := &MountTable{
+		Type: mountTableType,
+	}
+
+	localMounts := &MountTable{
+		Type: mountTableType,
+	}
+
+	for _, entry := range table.Entries {
+		if entry.Local {
+			localMounts.Entries = append(localMounts.Entries, entry)
+		} else {
+			nonLocalMounts.Entries = append(nonLocalMounts.Entries, entry)
+		}
+	}
+
 	// Encode the mount table into JSON and compress it (lzw).
-	compressedBytes, err := jsonutil.EncodeJSONAndCompress(table, nil)
+	compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil)
 	if err != nil {
 		c.logger.Error("core: failed to encode and/or compress the mount table", "error", err)
 		return err
@@ -510,6 +561,24 @@ func (c *Core) persistMounts(table *MountTable) error {
 		c.logger.Error("core: failed to persist mount table", "error", err)
 		return err
 	}
+
+	// Repeat with local mounts
+	compressedBytes, err = jsonutil.EncodeJSONAndCompress(localMounts, nil)
+	if err != nil {
+		c.logger.Error("core: failed to encode and/or compress the local mount table", "error", err)
+		return err
+	}
+
+	entry = &Entry{
+		Key:   coreLocalMountConfigPath,
+		Value: compressedBytes,
+	}
+
+	if err := c.barrier.Put(entry); err != nil {
+		c.logger.Error("core: failed to persist local mount table", "error", err)
+		return err
+	}

 	return nil
 }
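
Note: persistMounts and loadMounts now round-trip through two storage keys, and on load the local entries are appended back onto the replicated table. A compact sketch of that round trip against an in-memory map standing in for the barrier; the types and the store variable are invented for this note:

package main

import (
	"encoding/json"
	"fmt"
)

type mountEntry struct {
	Path  string `json:"path"`
	Local bool   `json:"local"`
}

type mountTable struct {
	Entries []*mountEntry `json:"entries"`
}

// store is a stand-in for the barrier, keyed by storage path.
var store = map[string][]byte{}

func persist(t *mountTable) {
	nonLocal, local := &mountTable{}, &mountTable{}
	for _, e := range t.Entries {
		if e.Local {
			local.Entries = append(local.Entries, e)
		} else {
			nonLocal.Entries = append(nonLocal.Entries, e)
		}
	}
	store["core/mounts"], _ = json.Marshal(nonLocal)
	store["core/local-mounts"], _ = json.Marshal(local)
}

func load() *mountTable {
	t, localT := &mountTable{}, &mountTable{}
	json.Unmarshal(store["core/mounts"], t)
	json.Unmarshal(store["core/local-mounts"], localT)
	// Recombine, as loadMounts does with localMountTable.Entries.
	t.Entries = append(t.Entries, localT.Entries...)
	return t
}

func main() {
	persist(&mountTable{Entries: []*mountEntry{
		{Path: "secret/"}, {Path: "scratch/", Local: true},
	}})
	for _, e := range load().Entries {
		fmt.Println(e.Path, e.Local)
	}
}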
@@ -532,15 +601,19 @@ func (c *Core) setupMounts() error {

 		// Create a barrier view using the UUID
 		view = NewBarrierView(c.barrier, barrierPath)

-		// Initialize the backend
-		backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
+		sysView := c.mountEntrySysView(entry)
+		// Create the new backend
+		backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil)
 		if err != nil {
 			c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
 			return errLoadMountsFailed
 		}

+		if err := backend.Initialize(); err != nil {
+			return err
+		}
+
 		switch entry.Type {
 		case "system":
 			c.systemBarrierView = view
@@ -616,10 +689,10 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi
 // mountEntrySysView creates a logical.SystemView from global and
 // mount-specific entries; because this should be called when setting
 // up a mountEntry, it doesn't check to ensure that me is not nil
-func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView {
+func (c *Core) mountEntrySysView(entry *MountEntry) logical.SystemView {
 	return dynamicSystemView{
 		core: c,
-		mountEntry: me,
+		mountEntry: entry,
 	}
 }

@@ -9,6 +9,7 @@ import (

 	"github.com/hashicorp/vault/audit"
 	"github.com/hashicorp/vault/helper/compressutil"
+	"github.com/hashicorp/vault/helper/jsonutil"
 	"github.com/hashicorp/vault/logical"
 )

@@ -82,6 +83,96 @@ func TestCore_Mount(t *testing.T) {
 	}
 }

+// Test that the local table actually gets populated as expected with local
+// entries, and that upon reading the entries from both are recombined
+// correctly
+func TestCore_Mount_Local(t *testing.T) {
+	c, _, _ := TestCoreUnsealed(t)
+
+	c.mounts = &MountTable{
+		Type: mountTableType,
+		Entries: []*MountEntry{
+			&MountEntry{
+				Table: mountTableType,
+				Path:  "noop/",
+				Type:  "generic",
+				UUID:  "abcd",
+			},
+			&MountEntry{
+				Table: mountTableType,
+				Path:  "noop2/",
+				Type:  "generic",
+				UUID:  "bcde",
+			},
+		},
+	}
+
+	// Both should set up successfully
+	err := c.setupMounts()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(c.mounts.Entries) != 2 {
+		t.Fatalf("expected two entries, got %d", len(c.mounts.Entries))
+	}
+
+	rawLocal, err := c.barrier.Get(coreLocalMountConfigPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if rawLocal == nil {
+		t.Fatal("expected non-nil local mounts")
+	}
+	localMountsTable := &MountTable{}
+	if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil {
+		t.Fatal(err)
+	}
+	if len(localMountsTable.Entries) > 0 {
+		t.Fatalf("expected no entries in local mount table, got %#v", localMountsTable)
+	}
+
+	c.mounts.Entries[1].Local = true
+	if err := c.persistMounts(c.mounts); err != nil {
+		t.Fatal(err)
+	}
+
+	rawLocal, err = c.barrier.Get(coreLocalMountConfigPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if rawLocal == nil {
+		t.Fatal("expected non-nil local mount")
+	}
+	localMountsTable = &MountTable{}
+	if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil {
+		t.Fatal(err)
+	}
+	if len(localMountsTable.Entries) != 1 {
+		t.Fatalf("expected one entry in local mount table, got %#v", localMountsTable)
+	}
+
+	oldMounts := c.mounts
+	if err := c.loadMounts(); err != nil {
+		t.Fatal(err)
+	}
+	compEntries := c.mounts.Entries[:0]
+	// Filter out required mounts
+	for _, v := range c.mounts.Entries {
+		if v.Type == "generic" {
+			compEntries = append(compEntries, v)
+		}
+	}
+	c.mounts.Entries = compEntries
+
+	if !reflect.DeepEqual(oldMounts, c.mounts) {
+		t.Fatalf("expected\n%#v\ngot\n%#v\n", oldMounts, c.mounts)
+	}
+
+	if len(c.mounts.Entries) != 2 {
+		t.Fatalf("expected two mount entries, got %#v", localMountsTable)
+	}
+}
+
 func TestCore_Unmount(t *testing.T) {
 	c, keys, _ := TestCoreUnsealed(t)
 	existed, err := c.unmount("secret")

@@ -2,11 +2,13 @@ package vault

 import (
 	"fmt"
+	"strings"
 	"time"

 	"github.com/armon/go-metrics"
 	"github.com/hashicorp/errwrap"
 	"github.com/hashicorp/golang-lru"
+	"github.com/hashicorp/vault/helper/consts"
 	"github.com/hashicorp/vault/helper/strutil"
 	"github.com/hashicorp/vault/logical"
 )
@@ -137,7 +139,13 @@ func (c *Core) setupPolicyStore() error {
 	view := c.systemBarrierView.SubView(policySubPath)

 	// Create the policy store
-	c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c})
+	sysView := &dynamicSystemView{core: c}
+	c.policyStore = NewPolicyStore(view, sysView)
+
+	if sysView.ReplicationState() == consts.ReplicationSecondary {
+		// Policies will sync from the primary
+		return nil
+	}

 	// Ensure that the default policy exists, and if not, create it
 	policy, err := c.policyStore.GetPolicy("default")
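
Note: short-circuiting setupPolicyStore on a replication secondary means the default policy is never written locally; it arrives via replication from the primary, and writing it here could conflict with replicated data. The guard pattern in isolation, with all names invented for this sketch:

package main

import "fmt"

type replicationState int

const (
	replicationDisabled replicationState = iota
	replicationPrimary
	replicationSecondary
)

func setupPolicies(state replicationState, write func(name string)) {
	if state == replicationSecondary {
		// Policies will sync from the primary; skip local writes.
		return
	}
	write("default")
}

func main() {
	setupPolicies(replicationSecondary, func(name string) {
		fmt.Println("wrote policy", name) // never fires
	})
	setupPolicies(replicationDisabled, func(name string) {
		fmt.Println("wrote policy", name) // only this one fires
	})
}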
@@ -173,6 +181,16 @@ func (c *Core) teardownPolicyStore() error {
 	return nil
 }

+func (ps *PolicyStore) invalidate(name string) {
+	if ps.lru == nil {
+		// Nothing to do if the cache is not used
+		return
+	}
+
+	// This may come with a prefixed "/" due to joining the file path
+	ps.lru.Remove(strings.TrimPrefix(name, "/"))
+}
+
 // SetPolicy is used to create or update the given policy
 func (ps *PolicyStore) SetPolicy(p *Policy) error {
 	defer metrics.MeasureSince([]string{"policy", "set_policy"}, time.Now())