Port some updates

Jeff Mitchell 2017-01-06 15:42:18 -05:00
parent 2fd5ab5f10
commit 9e5d1eaac9
14 changed files with 125 additions and 110 deletions

View File

@@ -2,6 +2,7 @@ package logical
import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/jsonutil"
)
@@ -37,3 +38,67 @@ func StorageEntryJSON(k string, v interface{}) (*StorageEntry, error) {
Value: encodedBytes,
}, nil
}
type ClearableView interface {
List(string) ([]string, error)
Delete(string) error
}
// ScanView is used to scan all the keys in a view iteratively
func ScanView(view ClearableView, cb func(path string)) error {
frontier := []string{""}
for len(frontier) > 0 {
n := len(frontier)
current := frontier[n-1]
frontier = frontier[:n-1]
// List the contents
contents, err := view.List(current)
if err != nil {
return fmt.Errorf("list failed at path '%s': %v", current, err)
}
// Handle the contents in the directory
for _, c := range contents {
fullPath := current + c
if strings.HasSuffix(c, "/") {
frontier = append(frontier, fullPath)
} else {
cb(fullPath)
}
}
}
return nil
}
// CollectKeys is used to collect all the keys in a view
func CollectKeys(view ClearableView) ([]string, error) {
// Accumulate the keys
var existing []string
cb := func(path string) {
existing = append(existing, path)
}
// Scan for all the keys
if err := ScanView(view, cb); err != nil {
return nil, err
}
return existing, nil
}
// ClearView is used to delete all the keys in a view
func ClearView(view ClearableView) error {
// Collect all the keys
keys, err := CollectKeys(view)
if err != nil {
return err
}
// Delete all the keys
for _, key := range keys {
if err := view.Delete(key); err != nil {
return err
}
}
return nil
}
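
These helpers are written against the small ClearableView interface rather than a concrete barrier type, so any storage that can List and Delete can be swept. A minimal sketch of exercising them against a hypothetical map-backed view (memView is illustrative only, not part of this commit; it lives in the logical package so the helpers are unqualified):

// memView is a hypothetical in-memory ClearableView; keys are full paths
// such as "foo" and "dir/bar", and List returns the direct children of a
// prefix, collapsing nested entries to a single "dir/" marker the way a
// real storage backend would.
type memView struct {
	data map[string]struct{}
}

func (m *memView) List(prefix string) ([]string, error) {
	seen := make(map[string]struct{})
	var out []string
	for k := range m.data {
		if !strings.HasPrefix(k, prefix) {
			continue
		}
		rest := strings.TrimPrefix(k, prefix)
		if idx := strings.Index(rest, "/"); idx != -1 {
			rest = rest[:idx+1]
		}
		if _, ok := seen[rest]; !ok {
			seen[rest] = struct{}{}
			out = append(out, rest)
		}
	}
	return out, nil
}

func (m *memView) Delete(key string) error {
	delete(m.data, key)
	return nil
}

Used with the functions above:

v := &memView{data: map[string]struct{}{"foo": {}, "dir/bar": {}}}
keys, _ := CollectKeys(v) // "foo" and "dir/bar", in traversal order
_ = ClearView(v)          // deletes both leaves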

View File

@@ -45,6 +45,12 @@ type HABackend interface {
HAEnabled() bool
}
// Purgable is an optional interface for backends that support
// purging of their caches.
type Purgable interface {
Purge()
}
// RedirectDetect is an optional interface that an HABackend
// can implement. If they do, a redirect address can be automatically
// detected.
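
Because Purgable is consumed with a type assertion, a backend opts in just by having the method; no registration or configuration changes. A minimal sketch of both sides of that contract, with a hypothetical countingCache standing in for a real cache backend:

// countingCache is a hypothetical cache wrapper used only to show how a
// backend satisfies Purgable.
type countingCache struct {
	entries map[string][]byte
}

// Purge drops every cached entry, implementing Purgable.
func (c *countingCache) Purge() {
	c.entries = make(map[string][]byte)
}

// purgeIfSupported mirrors the call sites in core: assert, then purge.
func purgeIfSupported(b interface{}) {
	if p, ok := b.(Purgable); ok {
		p.Purge()
	}
}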

View File

@@ -145,7 +145,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
// Clear the data in the view
if view != nil {
if err := ClearView(view); err != nil {
if err := logical.ClearView(view); err != nil {
return true, err
}
}

View File

@@ -247,7 +247,7 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) {
}
// View should be empty
out, err := CollectKeys(view)
out, err := logical.CollectKeys(view)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@@ -101,62 +101,3 @@ func (v *BarrierView) expandKey(suffix string) string {
func (v *BarrierView) truncateKey(full string) string {
return strings.TrimPrefix(full, v.prefix)
}
// ScanView is used to scan all the keys in a view iteratively
func ScanView(view *BarrierView, cb func(path string)) error {
frontier := []string{""}
for len(frontier) > 0 {
n := len(frontier)
current := frontier[n-1]
frontier = frontier[:n-1]
// List the contents
contents, err := view.List(current)
if err != nil {
return fmt.Errorf("list failed at path '%s': %v", current, err)
}
// Handle the contents in the directory
for _, c := range contents {
fullPath := current + c
if strings.HasSuffix(c, "/") {
frontier = append(frontier, fullPath)
} else {
cb(fullPath)
}
}
}
return nil
}
// CollectKeys is used to collect all the keys in a view
func CollectKeys(view *BarrierView) ([]string, error) {
// Accumulate the keys
var existing []string
cb := func(path string) {
existing = append(existing, path)
}
// Scan for all the keys
if err := ScanView(view, cb); err != nil {
return nil, err
}
return existing, nil
}
// ClearView is used to delete all the keys in a view
func ClearView(view *BarrierView) error {
// Collect all the keys
keys, err := CollectKeys(view)
if err != nil {
return err
}
// Delete all the keys
for _, key := range keys {
if err := view.Delete(key); err != nil {
return err
}
}
return nil
}

View File

@@ -202,7 +202,7 @@ func TestBarrierView_Scan(t *testing.T) {
}
// Collect the keys
if err := ScanView(view, cb); err != nil {
if err := logical.ScanView(view, cb); err != nil {
t.Fatalf("err: %v", err)
}
@@ -235,7 +235,7 @@ func TestBarrierView_CollectKeys(t *testing.T) {
}
// Collect the keys
out, err := CollectKeys(view)
out, err := logical.CollectKeys(view)
if err != nil {
t.Fatalf("err: %v", err)
}
@@ -269,12 +269,12 @@ func TestBarrierView_ClearView(t *testing.T) {
}
// Clear the keys
if err := ClearView(view); err != nil {
if err := logical.ClearView(view); err != nil {
t.Fatalf("err: %v", err)
}
// Collect the keys
out, err := CollectKeys(view)
out, err := logical.CollectKeys(view)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@@ -20,6 +20,7 @@ import (
"golang.org/x/net/http2"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/forwarding"
"github.com/hashicorp/vault/helper/jsonutil"
@@ -43,9 +44,9 @@ var (
// This can be one of a few key types so the different params may or may not be filled
type clusterKeyParams struct {
Type string `json:"type"`
X *big.Int `json:"x,omitempty"`
Y *big.Int `json:"y,omitempty"`
D *big.Int `json:"d,omitempty"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
}
type activeConnection struct {
@@ -62,8 +63,8 @@ type Cluster struct {
ID string `json:"id" structs:"id" mapstructure:"id"`
}
// Cluster fetches the details of either local or global cluster based on the
// input. This method errors out when Vault is sealed.
// Cluster fetches the details of the local cluster. This method errors out
// when Vault is sealed.
func (c *Core) Cluster() (*Cluster, error) {
var cluster Cluster
@@ -91,7 +92,7 @@ func (c *Core) Cluster() (*Cluster, error) {
// This sets our local cluster cert and private key based on the advertisement.
// It also ensures the cert is in our local cluster cert pool.
func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) error {
switch {
case adv.ClusterAddr == "":
// Clustering disabled on the server, don't try to look for params
@@ -136,7 +137,7 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
return fmt.Errorf("error parsing local cluster certificate: %v", err)
}
c.localClusterCertPool.AddCert(cert)
c.clusterCertPool.AddCert(cert)
return nil
}
@@ -145,6 +146,10 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
// Entries will be created only if they are not already present. If clusterName
// is not supplied, this method will auto-generate it.
func (c *Core) setupCluster() error {
// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()
// Check if storage index is already present or not
cluster, err := c.Cluster()
if err != nil {
@@ -194,10 +199,6 @@ func (c *Core) setupCluster() error {
// If we're using HA, generate server-to-server parameters
if c.ha != nil {
// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()
// Create a private key
{
c.logger.Trace("core: generating cluster private key")
@@ -240,13 +241,13 @@ func (c *Core) setupCluster() error {
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
if err != nil {
c.logger.Error("core: error generating self-signed cert", "error", err)
return fmt.Errorf("unable to generate local cluster certificate: %v", err)
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
}
_, err = x509.ParseCertificate(certBytes)
if err != nil {
c.logger.Error("core: error parsing self-signed cert", "error", err)
return fmt.Errorf("error parsing generated certificate: %v", err)
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
}
c.localClusterCert = certBytes
@@ -363,7 +364,7 @@ func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
}
// This is idempotent, so be sure it's been added
c.localClusterCertPool.AddCert(parsedCert)
c.clusterCertPool.AddCert(parsedCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
@@ -372,10 +373,10 @@ func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
PrivateKey: c.localClusterPrivateKey,
},
},
RootCAs: c.localClusterCertPool,
RootCAs: c.clusterCertPool,
ServerName: parsedCert.Subject.CommonName,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.localClusterCertPool,
ClientCAs: c.clusterCertPool,
}
return tlsConfig, nil
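
Taken together, the cluster TLS flow above reduces to: generate a key, self-sign a certificate, add it to the shared pool, and use that one pool for both RootCAs and ClientCAs with mandatory client certs. A condensed, self-contained sketch of that shape (the names and certificate details here are illustrative, not the values Vault uses):

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"math/big"
	"time"
)

// selfSignedMutualTLS builds a tls.Config in which one self-signed cert
// is both the serving identity and the only trusted CA for clients.
func selfSignedMutualTLS() (*tls.Config, error) {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "cluster-internal"},
		NotBefore:             time.Now().Add(-time.Minute),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		IsCA:                  true,
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
	if err != nil {
		return nil, err
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return nil, err
	}
	pool := x509.NewCertPool()
	pool.AddCert(cert)

	return &tls.Config{
		Certificates: []tls.Certificate{{
			Certificate: [][]byte{der},
			PrivateKey:  key,
		}},
		RootCAs:    pool,
		ServerName: cert.Subject.CommonName,
		ClientAuth: tls.RequireAndVerifyClientCert,
		ClientCAs:  pool,
	}, nil
}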

View File

@@ -267,8 +267,8 @@ type Core struct {
localClusterPrivateKey crypto.Signer
// The local cluster cert
localClusterCert []byte
// The cert pool containing the self-signed CA as a trusted CA
localClusterCertPool *x509.CertPool
// The cert pool containing trusted cluster CAs
clusterCertPool *x509.CertPool
// The TCP addresses we should use for clustering
clusterListenerAddrs []*net.TCPAddr
// The setup function that gives us the handler to use
@@ -378,16 +378,6 @@ func NewCore(conf *CoreConfig) (*Core, error) {
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
}
// Wrap the backend in a cache unless disabled
if !conf.DisableCache {
_, isCache := conf.Physical.(*physical.Cache)
_, isInmem := conf.Physical.(*physical.InmemBackend)
if !isCache && !isInmem {
cache := physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
conf.Physical = cache
}
}
if !conf.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
@@ -425,11 +415,16 @@ func NewCore(conf *CoreConfig) (*Core, error) {
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
localClusterCertPool: x509.NewCertPool(),
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the backend in a cache unless disabled
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
c.ha = conf.HAPhysical
}
@@ -714,7 +709,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
if !oldAdv {
// Ensure we are using current values
err = c.loadClusterTLS(adv)
err = c.loadLocalClusterTLS(adv)
if err != nil {
return false, "", err
}
@@ -1134,8 +1129,10 @@ func (c *Core) postUnseal() (retErr error) {
}
}()
c.logger.Info("core: post-unseal setup starting")
if cache, ok := c.physical.(*physical.Cache); ok {
cache.Purge()
// Purge the backend if supported
if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge()
}
// HA mode requires us to handle keyring rotation and rekeying
if c.ha != nil {
@@ -1232,8 +1229,9 @@ func (c *Core) preSeal() error {
if err := c.unloadMounts(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
}
if cache, ok := c.physical.(*physical.Cache); ok {
cache.Purge()
// Purge the backend if supported
if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge()
}
c.logger.Info("core: pre-seal teardown complete")
return result
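
One side effect of moving the cache wrapping below the Core construction is that the decision collapses to a single guard: wrap unless caching is disabled or the backend is already a cache. A minimal sketch of that pattern with stand-in types (Backend and Cache here are illustrative, not the physical package's definitions):

// Backend stands in for physical.Backend; Cache for physical.Cache.
type Backend interface {
	Get(key string) ([]byte, error)
}

type Cache struct {
	Backend
}

func NewCache(b Backend) *Cache { return &Cache{Backend: b} }

// wrapInCache mirrors the guard in NewCore: wrap the backend in a cache
// unless caching is disabled or it is already wrapped.
func wrapInCache(b Backend, disableCache bool) Backend {
	if _, isCache := b.(*Cache); disableCache || isCache {
		return b
	}
	return NewCache(b)
}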

View File

@@ -119,7 +119,7 @@ func (m *ExpirationManager) Restore() error {
defer m.pendingLock.Unlock()
// Accumulate existing leases
existing, err := CollectKeys(m.idView)
existing, err := logical.CollectKeys(m.idView)
if err != nil {
return fmt.Errorf("failed to scan for leases: %v", err)
}
@@ -292,7 +292,7 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error
// Accumulate existing leases
sub := m.idView.SubView(prefix)
existing, err := CollectKeys(sub)
existing, err := logical.CollectKeys(sub)
if err != nil {
return fmt.Errorf("failed to scan for leases: %v", err)
}

View File

@@ -60,7 +60,7 @@ func (b *CubbyholeBackend) revoke(saltedToken string) error {
return fmt.Errorf("cubbyhole: client token empty during revocation")
}
if err := ClearView(b.storageView.(*BarrierView).SubView(saltedToken + "/")); err != nil {
if err := logical.ClearView(b.storageView.(*BarrierView).SubView(saltedToken + "/")); err != nil {
return err
}

View File

@@ -261,7 +261,7 @@ func (c *Core) unmount(path string) (bool, error) {
}
// Clear the data in the view
if err := ClearView(view); err != nil {
if err := logical.ClearView(view); err != nil {
return true, err
}

View File

@@ -187,7 +187,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
}
// View should be empty
out, err := CollectKeys(view)
out, err := logical.CollectKeys(view)
if err != nil {
t.Fatalf("err: %v", err)
}
@@ -305,7 +305,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
}
// View should not be empty
out, err := CollectKeys(view)
out, err := logical.CollectKeys(view)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@@ -275,7 +275,7 @@ func (ps *PolicyStore) ListPolicies() ([]string, error) {
defer metrics.MeasureSince([]string{"policy", "list_policies"}, time.Now())
// Scan the view, since the policy names are the same as the
// key names.
keys, err := CollectKeys(ps.view)
keys, err := logical.CollectKeys(ps.view)
for _, nonAssignable := range nonAssignablePolicies {
deleteIndex := -1
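
The hunk cuts off mid-filter here. The usual shape of this kind of delete-by-index loop (a sketch of the pattern, not the verbatim continuation of the hunk) is:

// removeAll strips each non-assignable name from keys using the
// find-then-splice pattern the truncated loop above begins.
func removeAll(keys, nonAssignable []string) []string {
	for _, na := range nonAssignable {
		deleteIndex := -1
		for i, k := range keys {
			if k == na {
				deleteIndex = i
				break
			}
		}
		if deleteIndex != -1 {
			keys = append(keys[:deleteIndex], keys[deleteIndex+1:]...)
		}
	}
	return keys
}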

View File

@@ -225,7 +225,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
// the TLS state.
ctx, cancelFunc := context.WithCancel(context.Background())
c.rpcClientConnCancelFunc = cancelFunc
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer()), grpc.WithInsecure())
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "")), grpc.WithInsecure())
if err != nil {
c.logger.Error("core/refreshRequestForwardingConnection: err setting up rpc client", "error", err)
return err
@@ -330,14 +330,18 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
// getGRPCDialer is used to return a dialer that has the correct TLS
// configuration. Otherwise gRPC tries to be helpful and stomps all over our
// NextProtos.
func (c *Core) getGRPCDialer() func(string, time.Duration) (net.Conn, error) {
func (c *Core) getGRPCDialer(alpnProto, serverName string) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
tlsConfig, err := c.ClusterTLSConfig()
if err != nil {
c.logger.Error("core/getGRPCDialer: failed to get tls configuration", "error", err)
c.logger.Error("core: failed to get tls configuration", "error", err)
return nil, err
}
tlsConfig.NextProtos = []string{"req_fw_sb-act_v1"}
if serverName != "" {
tlsConfig.ServerName = serverName
}
tlsConfig.NextProtos = []string{alpnProto}
dialer := &net.Dialer{
Timeout: timeout,
}
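
The dialer has to hand gRPC a connection that has already completed the TLS handshake with the right ALPN protocol, since grpc.WithInsecure() keeps gRPC itself out of the TLS path. A condensed sketch of such a dialer, assuming a ready-made *tls.Config (dialerWithALPN is a hypothetical helper, not this commit's code):

// dialerWithALPN returns a function in the shape grpc.WithDialer expects,
// pinning the ALPN protocol (and optionally the SNI server name) before
// returning the TLS connection to gRPC.
func dialerWithALPN(base *tls.Config, alpnProto, serverName string) func(string, time.Duration) (net.Conn, error) {
	return func(addr string, timeout time.Duration) (net.Conn, error) {
		cfg := base.Clone()
		if serverName != "" {
			cfg.ServerName = serverName
		}
		cfg.NextProtos = []string{alpnProto}
		dialer := &net.Dialer{Timeout: timeout}
		return tls.DialWithDialer(dialer, "tcp", addr, cfg)
	}
}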