From 9e5d1eaac9d83c775e3f578a40e918556162f0ce Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 6 Jan 2017 15:42:18 -0500 Subject: [PATCH] Port some updates --- logical/storage.go | 65 +++++++++++++++++++++++++++++++++++++ physical/physical.go | 6 ++++ vault/auth.go | 2 +- vault/auth_test.go | 2 +- vault/barrier_view.go | 59 --------------------------------- vault/barrier_view_test.go | 8 ++--- vault/cluster.go | 33 ++++++++++--------- vault/core.go | 34 +++++++++---------- vault/expiration.go | 4 +-- vault/logical_cubbyhole.go | 2 +- vault/mount.go | 2 +- vault/mount_test.go | 4 +-- vault/policy_store.go | 2 +- vault/request_forwarding.go | 12 ++++--- 14 files changed, 125 insertions(+), 110 deletions(-) diff --git a/logical/storage.go b/logical/storage.go index f7f4d1a64..8c15f4bf6 100644 --- a/logical/storage.go +++ b/logical/storage.go @@ -2,6 +2,7 @@ package logical import ( "fmt" + "strings" "github.com/hashicorp/vault/helper/jsonutil" ) @@ -37,3 +38,67 @@ func StorageEntryJSON(k string, v interface{}) (*StorageEntry, error) { Value: encodedBytes, }, nil } + +type ClearableView interface { + List(string) ([]string, error) + Delete(string) error +} + +// ScanView is used to scan all the keys in a view iteratively +func ScanView(view ClearableView, cb func(path string)) error { + frontier := []string{""} + for len(frontier) > 0 { + n := len(frontier) + current := frontier[n-1] + frontier = frontier[:n-1] + + // List the contents + contents, err := view.List(current) + if err != nil { + return fmt.Errorf("list failed at path '%s': %v", current, err) + } + + // Handle the contents in the directory + for _, c := range contents { + fullPath := current + c + if strings.HasSuffix(c, "/") { + frontier = append(frontier, fullPath) + } else { + cb(fullPath) + } + } + } + return nil +} + +// CollectKeys is used to collect all the keys in a view +func CollectKeys(view ClearableView) ([]string, error) { + // Accumulate the keys + var existing []string + cb := func(path string) { + existing = append(existing, path) + } + + // Scan for all the keys + if err := ScanView(view, cb); err != nil { + return nil, err + } + return existing, nil +} + +// ClearView is used to delete all the keys in a view +func ClearView(view ClearableView) error { + // Collect all the keys + keys, err := CollectKeys(view) + if err != nil { + return err + } + + // Delete all the keys + for _, key := range keys { + if err := view.Delete(key); err != nil { + return err + } + } + return nil +} diff --git a/physical/physical.go b/physical/physical.go index d63244336..9427820bc 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -45,6 +45,12 @@ type HABackend interface { HAEnabled() bool } +// Purgable is an optional interface for backends that support +// purging of their caches. +type Purgable interface { + Purge() +} + // RedirectDetect is an optional interface that an HABackend // can implement. If they do, a redirect address can be automatically // detected. diff --git a/vault/auth.go b/vault/auth.go index 9acb49f56..c21b996f2 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -145,7 +145,7 @@ func (c *Core) disableCredential(path string) (bool, error) { // Clear the data in the view if view != nil { - if err := ClearView(view); err != nil { + if err := logical.ClearView(view); err != nil { return true, err } } diff --git a/vault/auth_test.go b/vault/auth_test.go index 7bff2e335..b7abdfa1a 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -247,7 +247,7 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) { } // View should be empty - out, err := CollectKeys(view) + out, err := logical.CollectKeys(view) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/barrier_view.go b/vault/barrier_view.go index 772c3b96f..67964c646 100644 --- a/vault/barrier_view.go +++ b/vault/barrier_view.go @@ -101,62 +101,3 @@ func (v *BarrierView) expandKey(suffix string) string { func (v *BarrierView) truncateKey(full string) string { return strings.TrimPrefix(full, v.prefix) } - -// ScanView is used to scan all the keys in a view iteratively -func ScanView(view *BarrierView, cb func(path string)) error { - frontier := []string{""} - for len(frontier) > 0 { - n := len(frontier) - current := frontier[n-1] - frontier = frontier[:n-1] - - // List the contents - contents, err := view.List(current) - if err != nil { - return fmt.Errorf("list failed at path '%s': %v", current, err) - } - - // Handle the contents in the directory - for _, c := range contents { - fullPath := current + c - if strings.HasSuffix(c, "/") { - frontier = append(frontier, fullPath) - } else { - cb(fullPath) - } - } - } - return nil -} - -// CollectKeys is used to collect all the keys in a view -func CollectKeys(view *BarrierView) ([]string, error) { - // Accumulate the keys - var existing []string - cb := func(path string) { - existing = append(existing, path) - } - - // Scan for all the keys - if err := ScanView(view, cb); err != nil { - return nil, err - } - return existing, nil -} - -// ClearView is used to delete all the keys in a view -func ClearView(view *BarrierView) error { - // Collect all the keys - keys, err := CollectKeys(view) - if err != nil { - return err - } - - // Delete all the keys - for _, key := range keys { - if err := view.Delete(key); err != nil { - return err - } - } - return nil -} diff --git a/vault/barrier_view_test.go b/vault/barrier_view_test.go index d80aaeed4..6b80c54ce 100644 --- a/vault/barrier_view_test.go +++ b/vault/barrier_view_test.go @@ -202,7 +202,7 @@ func TestBarrierView_Scan(t *testing.T) { } // Collect the keys - if err := ScanView(view, cb); err != nil { + if err := logical.ScanView(view, cb); err != nil { t.Fatalf("err: %v", err) } @@ -235,7 +235,7 @@ func TestBarrierView_CollectKeys(t *testing.T) { } // Collect the keys - out, err := CollectKeys(view) + out, err := logical.CollectKeys(view) if err != nil { t.Fatalf("err: %v", err) } @@ -269,12 +269,12 @@ func TestBarrierView_ClearView(t *testing.T) { } // Clear the keys - if err := ClearView(view); err != nil { + if err := logical.ClearView(view); err != nil { t.Fatalf("err: %v", err) } // Collect the keys - out, err := CollectKeys(view) + out, err := logical.CollectKeys(view) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/cluster.go b/vault/cluster.go index d2dd63704..d146f98fe 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -20,6 +20,7 @@ import ( "golang.org/x/net/http2" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/forwarding" "github.com/hashicorp/vault/helper/jsonutil" @@ -43,9 +44,9 @@ var ( // This can be one of a few key types so the different params may or may not be filled type clusterKeyParams struct { Type string `json:"type"` - X *big.Int `json:"x,omitempty"` - Y *big.Int `json:"y,omitempty"` - D *big.Int `json:"d,omitempty"` + X *big.Int `json:"x" structs:"x" mapstructure:"x"` + Y *big.Int `json:"y" structs:"y" mapstructure:"y"` + D *big.Int `json:"d" structs:"d" mapstructure:"d"` } type activeConnection struct { @@ -62,8 +63,8 @@ type Cluster struct { ID string `json:"id" structs:"id" mapstructure:"id"` } -// Cluster fetches the details of either local or global cluster based on the -// input. This method errors out when Vault is sealed. +// Cluster fetches the details of the local cluster. This method errors out +// when Vault is sealed. func (c *Core) Cluster() (*Cluster, error) { var cluster Cluster @@ -91,7 +92,7 @@ func (c *Core) Cluster() (*Cluster, error) { // This sets our local cluster cert and private key based on the advertisement. // It also ensures the cert is in our local cluster cert pool. -func (c *Core) loadClusterTLS(adv activeAdvertisement) error { +func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) error { switch { case adv.ClusterAddr == "": // Clustering disabled on the server, don't try to look for params @@ -136,7 +137,7 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error { return fmt.Errorf("error parsing local cluster certificate: %v", err) } - c.localClusterCertPool.AddCert(cert) + c.clusterCertPool.AddCert(cert) return nil } @@ -145,6 +146,10 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error { // Entries will be created only if they are not already present. If clusterName // is not supplied, this method will auto-generate it. func (c *Core) setupCluster() error { + // Prevent data races with the TLS parameters + c.clusterParamsLock.Lock() + defer c.clusterParamsLock.Unlock() + // Check if storage index is already present or not cluster, err := c.Cluster() if err != nil { @@ -194,10 +199,6 @@ func (c *Core) setupCluster() error { // If we're using HA, generate server-to-server parameters if c.ha != nil { - // Prevent data races with the TLS parameters - c.clusterParamsLock.Lock() - defer c.clusterParamsLock.Unlock() - // Create a private key { c.logger.Trace("core: generating cluster private key") @@ -240,13 +241,13 @@ func (c *Core) setupCluster() error { certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey) if err != nil { c.logger.Error("core: error generating self-signed cert", "error", err) - return fmt.Errorf("unable to generate local cluster certificate: %v", err) + return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err) } _, err = x509.ParseCertificate(certBytes) if err != nil { c.logger.Error("core: error parsing self-signed cert", "error", err) - return fmt.Errorf("error parsing generated certificate: %v", err) + return errwrap.Wrapf("error parsing generated certificate: {{err}}", err) } c.localClusterCert = certBytes @@ -363,7 +364,7 @@ func (c *Core) ClusterTLSConfig() (*tls.Config, error) { } // This is idempotent, so be sure it's been added - c.localClusterCertPool.AddCert(parsedCert) + c.clusterCertPool.AddCert(parsedCert) tlsConfig := &tls.Config{ Certificates: []tls.Certificate{ @@ -372,10 +373,10 @@ func (c *Core) ClusterTLSConfig() (*tls.Config, error) { PrivateKey: c.localClusterPrivateKey, }, }, - RootCAs: c.localClusterCertPool, + RootCAs: c.clusterCertPool, ServerName: parsedCert.Subject.CommonName, ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: c.localClusterCertPool, + ClientCAs: c.clusterCertPool, } return tlsConfig, nil diff --git a/vault/core.go b/vault/core.go index 3bd2b9934..aea2a3337 100644 --- a/vault/core.go +++ b/vault/core.go @@ -267,8 +267,8 @@ type Core struct { localClusterPrivateKey crypto.Signer // The local cluster cert localClusterCert []byte - // The cert pool containing the self-signed CA as a trusted CA - localClusterCertPool *x509.CertPool + // The cert pool containing trusted cluster CAs + clusterCertPool *x509.CertPool // The TCP addresses we should use for clustering clusterListenerAddrs []*net.TCPAddr // The setup function that gives us the handler to use @@ -378,16 +378,6 @@ func NewCore(conf *CoreConfig) (*Core, error) { conf.Logger = logformat.NewVaultLogger(log.LevelTrace) } - // Wrap the backend in a cache unless disabled - if !conf.DisableCache { - _, isCache := conf.Physical.(*physical.Cache) - _, isInmem := conf.Physical.(*physical.InmemBackend) - if !isCache && !isInmem { - cache := physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger) - conf.Physical = cache - } - } - if !conf.DisableMlock { // Ensure our memory usage is locked into physical RAM if err := mlock.LockMemory(); err != nil { @@ -425,11 +415,16 @@ func NewCore(conf *CoreConfig) (*Core, error) { maxLeaseTTL: conf.MaxLeaseTTL, cachingDisabled: conf.DisableCache, clusterName: conf.ClusterName, - localClusterCertPool: x509.NewCertPool(), + clusterCertPool: x509.NewCertPool(), clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), } + // Wrap the backend in a cache unless disabled + if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache { + c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger) + } + if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() { c.ha = conf.HAPhysical } @@ -714,7 +709,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) { if !oldAdv { // Ensure we are using current values - err = c.loadClusterTLS(adv) + err = c.loadLocalClusterTLS(adv) if err != nil { return false, "", err } @@ -1134,8 +1129,10 @@ func (c *Core) postUnseal() (retErr error) { } }() c.logger.Info("core: post-unseal setup starting") - if cache, ok := c.physical.(*physical.Cache); ok { - cache.Purge() + + // Purge the backend if supported + if purgable, ok := c.physical.(physical.Purgable); ok { + purgable.Purge() } // HA mode requires us to handle keyring rotation and rekeying if c.ha != nil { @@ -1232,8 +1229,9 @@ func (c *Core) preSeal() error { if err := c.unloadMounts(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err)) } - if cache, ok := c.physical.(*physical.Cache); ok { - cache.Purge() + // Purge the backend if supported + if purgable, ok := c.physical.(physical.Purgable); ok { + purgable.Purge() } c.logger.Info("core: pre-seal teardown complete") return result diff --git a/vault/expiration.go b/vault/expiration.go index 51a7ff422..4328d0cd4 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -119,7 +119,7 @@ func (m *ExpirationManager) Restore() error { defer m.pendingLock.Unlock() // Accumulate existing leases - existing, err := CollectKeys(m.idView) + existing, err := logical.CollectKeys(m.idView) if err != nil { return fmt.Errorf("failed to scan for leases: %v", err) } @@ -292,7 +292,7 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error // Accumulate existing leases sub := m.idView.SubView(prefix) - existing, err := CollectKeys(sub) + existing, err := logical.CollectKeys(sub) if err != nil { return fmt.Errorf("failed to scan for leases: %v", err) } diff --git a/vault/logical_cubbyhole.go b/vault/logical_cubbyhole.go index f9dcd5429..76353b0be 100644 --- a/vault/logical_cubbyhole.go +++ b/vault/logical_cubbyhole.go @@ -60,7 +60,7 @@ func (b *CubbyholeBackend) revoke(saltedToken string) error { return fmt.Errorf("cubbyhole: client token empty during revocation") } - if err := ClearView(b.storageView.(*BarrierView).SubView(saltedToken + "/")); err != nil { + if err := logical.ClearView(b.storageView.(*BarrierView).SubView(saltedToken + "/")); err != nil { return err } diff --git a/vault/mount.go b/vault/mount.go index c19366c1e..6fcbf1ca9 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -261,7 +261,7 @@ func (c *Core) unmount(path string) (bool, error) { } // Clear the data in the view - if err := ClearView(view); err != nil { + if err := logical.ClearView(view); err != nil { return true, err } diff --git a/vault/mount_test.go b/vault/mount_test.go index e1b6bfc9f..268096a91 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -187,7 +187,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) { } // View should be empty - out, err := CollectKeys(view) + out, err := logical.CollectKeys(view) if err != nil { t.Fatalf("err: %v", err) } @@ -305,7 +305,7 @@ func TestCore_Remount_Cleanup(t *testing.T) { } // View should not be empty - out, err := CollectKeys(view) + out, err := logical.CollectKeys(view) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/policy_store.go b/vault/policy_store.go index 46cfdf708..873a8dde8 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -275,7 +275,7 @@ func (ps *PolicyStore) ListPolicies() ([]string, error) { defer metrics.MeasureSince([]string{"policy", "list_policies"}, time.Now()) // Scan the view, since the policy names are the same as the // key names. - keys, err := CollectKeys(ps.view) + keys, err := logical.CollectKeys(ps.view) for _, nonAssignable := range nonAssignablePolicies { deleteIndex := -1 diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 128fedd00..def2c1021 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -225,7 +225,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error { // the TLS state. ctx, cancelFunc := context.WithCancel(context.Background()) c.rpcClientConnCancelFunc = cancelFunc - c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer()), grpc.WithInsecure()) + c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "")), grpc.WithInsecure()) if err != nil { c.logger.Error("core/refreshRequestForwardingConnection: err setting up rpc client", "error", err) return err @@ -330,14 +330,18 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro // getGRPCDialer is used to return a dialer that has the correct TLS // configuration. Otherwise gRPC tries to be helpful and stomps all over our // NextProtos. -func (c *Core) getGRPCDialer() func(string, time.Duration) (net.Conn, error) { +func (c *Core) getGRPCDialer(alpnProto, serverName string) func(string, time.Duration) (net.Conn, error) { return func(addr string, timeout time.Duration) (net.Conn, error) { tlsConfig, err := c.ClusterTLSConfig() if err != nil { - c.logger.Error("core/getGRPCDialer: failed to get tls configuration", "error", err) + c.logger.Error("core: failed to get tls configuration", "error", err) return nil, err } - tlsConfig.NextProtos = []string{"req_fw_sb-act_v1"} + if serverName != "" { + tlsConfig.ServerName = serverName + } + + tlsConfig.NextProtos = []string{alpnProto} dialer := &net.Dialer{ Timeout: timeout, }