Port some updates
This commit is contained in:
parent
2fd5ab5f10
commit
9e5d1eaac9
|
@ -2,6 +2,7 @@ package logical
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
)
|
||||
|
@ -37,3 +38,67 @@ func StorageEntryJSON(k string, v interface{}) (*StorageEntry, error) {
|
|||
Value: encodedBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ClearableView interface {
|
||||
List(string) ([]string, error)
|
||||
Delete(string) error
|
||||
}
|
||||
|
||||
// ScanView is used to scan all the keys in a view iteratively
|
||||
func ScanView(view ClearableView, cb func(path string)) error {
|
||||
frontier := []string{""}
|
||||
for len(frontier) > 0 {
|
||||
n := len(frontier)
|
||||
current := frontier[n-1]
|
||||
frontier = frontier[:n-1]
|
||||
|
||||
// List the contents
|
||||
contents, err := view.List(current)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list failed at path '%s': %v", current, err)
|
||||
}
|
||||
|
||||
// Handle the contents in the directory
|
||||
for _, c := range contents {
|
||||
fullPath := current + c
|
||||
if strings.HasSuffix(c, "/") {
|
||||
frontier = append(frontier, fullPath)
|
||||
} else {
|
||||
cb(fullPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectKeys is used to collect all the keys in a view
|
||||
func CollectKeys(view ClearableView) ([]string, error) {
|
||||
// Accumulate the keys
|
||||
var existing []string
|
||||
cb := func(path string) {
|
||||
existing = append(existing, path)
|
||||
}
|
||||
|
||||
// Scan for all the keys
|
||||
if err := ScanView(view, cb); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// ClearView is used to delete all the keys in a view
|
||||
func ClearView(view ClearableView) error {
|
||||
// Collect all the keys
|
||||
keys, err := CollectKeys(view)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete all the keys
|
||||
for _, key := range keys {
|
||||
if err := view.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -45,6 +45,12 @@ type HABackend interface {
|
|||
HAEnabled() bool
|
||||
}
|
||||
|
||||
// Purgable is an optional interface for backends that support
|
||||
// purging of their caches.
|
||||
type Purgable interface {
|
||||
Purge()
|
||||
}
|
||||
|
||||
// RedirectDetect is an optional interface that an HABackend
|
||||
// can implement. If they do, a redirect address can be automatically
|
||||
// detected.
|
||||
|
|
|
@ -145,7 +145,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
|
|||
|
||||
// Clear the data in the view
|
||||
if view != nil {
|
||||
if err := ClearView(view); err != nil {
|
||||
if err := logical.ClearView(view); err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -247,7 +247,7 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) {
|
|||
}
|
||||
|
||||
// View should be empty
|
||||
out, err := CollectKeys(view)
|
||||
out, err := logical.CollectKeys(view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -101,62 +101,3 @@ func (v *BarrierView) expandKey(suffix string) string {
|
|||
func (v *BarrierView) truncateKey(full string) string {
|
||||
return strings.TrimPrefix(full, v.prefix)
|
||||
}
|
||||
|
||||
// ScanView is used to scan all the keys in a view iteratively
|
||||
func ScanView(view *BarrierView, cb func(path string)) error {
|
||||
frontier := []string{""}
|
||||
for len(frontier) > 0 {
|
||||
n := len(frontier)
|
||||
current := frontier[n-1]
|
||||
frontier = frontier[:n-1]
|
||||
|
||||
// List the contents
|
||||
contents, err := view.List(current)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list failed at path '%s': %v", current, err)
|
||||
}
|
||||
|
||||
// Handle the contents in the directory
|
||||
for _, c := range contents {
|
||||
fullPath := current + c
|
||||
if strings.HasSuffix(c, "/") {
|
||||
frontier = append(frontier, fullPath)
|
||||
} else {
|
||||
cb(fullPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectKeys is used to collect all the keys in a view
|
||||
func CollectKeys(view *BarrierView) ([]string, error) {
|
||||
// Accumulate the keys
|
||||
var existing []string
|
||||
cb := func(path string) {
|
||||
existing = append(existing, path)
|
||||
}
|
||||
|
||||
// Scan for all the keys
|
||||
if err := ScanView(view, cb); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// ClearView is used to delete all the keys in a view
|
||||
func ClearView(view *BarrierView) error {
|
||||
// Collect all the keys
|
||||
keys, err := CollectKeys(view)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete all the keys
|
||||
for _, key := range keys {
|
||||
if err := view.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -202,7 +202,7 @@ func TestBarrierView_Scan(t *testing.T) {
|
|||
}
|
||||
|
||||
// Collect the keys
|
||||
if err := ScanView(view, cb); err != nil {
|
||||
if err := logical.ScanView(view, cb); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -235,7 +235,7 @@ func TestBarrierView_CollectKeys(t *testing.T) {
|
|||
}
|
||||
|
||||
// Collect the keys
|
||||
out, err := CollectKeys(view)
|
||||
out, err := logical.CollectKeys(view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -269,12 +269,12 @@ func TestBarrierView_ClearView(t *testing.T) {
|
|||
}
|
||||
|
||||
// Clear the keys
|
||||
if err := ClearView(view); err != nil {
|
||||
if err := logical.ClearView(view); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Collect the keys
|
||||
out, err := CollectKeys(view)
|
||||
out, err := logical.CollectKeys(view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/forwarding"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
|
@ -43,9 +44,9 @@ var (
|
|||
// This can be one of a few key types so the different params may or may not be filled
|
||||
type clusterKeyParams struct {
|
||||
Type string `json:"type"`
|
||||
X *big.Int `json:"x,omitempty"`
|
||||
Y *big.Int `json:"y,omitempty"`
|
||||
D *big.Int `json:"d,omitempty"`
|
||||
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
|
||||
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
|
||||
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
|
||||
}
|
||||
|
||||
type activeConnection struct {
|
||||
|
@ -62,8 +63,8 @@ type Cluster struct {
|
|||
ID string `json:"id" structs:"id" mapstructure:"id"`
|
||||
}
|
||||
|
||||
// Cluster fetches the details of either local or global cluster based on the
|
||||
// input. This method errors out when Vault is sealed.
|
||||
// Cluster fetches the details of the local cluster. This method errors out
|
||||
// when Vault is sealed.
|
||||
func (c *Core) Cluster() (*Cluster, error) {
|
||||
var cluster Cluster
|
||||
|
||||
|
@ -91,7 +92,7 @@ func (c *Core) Cluster() (*Cluster, error) {
|
|||
|
||||
// This sets our local cluster cert and private key based on the advertisement.
|
||||
// It also ensures the cert is in our local cluster cert pool.
|
||||
func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
|
||||
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) error {
|
||||
switch {
|
||||
case adv.ClusterAddr == "":
|
||||
// Clustering disabled on the server, don't try to look for params
|
||||
|
@ -136,7 +137,7 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
|
|||
return fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||
}
|
||||
|
||||
c.localClusterCertPool.AddCert(cert)
|
||||
c.clusterCertPool.AddCert(cert)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -145,6 +146,10 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
|
|||
// Entries will be created only if they are not already present. If clusterName
|
||||
// is not supplied, this method will auto-generate it.
|
||||
func (c *Core) setupCluster() error {
|
||||
// Prevent data races with the TLS parameters
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
// Check if storage index is already present or not
|
||||
cluster, err := c.Cluster()
|
||||
if err != nil {
|
||||
|
@ -194,10 +199,6 @@ func (c *Core) setupCluster() error {
|
|||
|
||||
// If we're using HA, generate server-to-server parameters
|
||||
if c.ha != nil {
|
||||
// Prevent data races with the TLS parameters
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
// Create a private key
|
||||
{
|
||||
c.logger.Trace("core: generating cluster private key")
|
||||
|
@ -240,13 +241,13 @@ func (c *Core) setupCluster() error {
|
|||
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
|
||||
if err != nil {
|
||||
c.logger.Error("core: error generating self-signed cert", "error", err)
|
||||
return fmt.Errorf("unable to generate local cluster certificate: %v", err)
|
||||
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
|
||||
}
|
||||
|
||||
_, err = x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
c.logger.Error("core: error parsing self-signed cert", "error", err)
|
||||
return fmt.Errorf("error parsing generated certificate: %v", err)
|
||||
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
|
||||
}
|
||||
|
||||
c.localClusterCert = certBytes
|
||||
|
@ -363,7 +364,7 @@ func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
|||
}
|
||||
|
||||
// This is idempotent, so be sure it's been added
|
||||
c.localClusterCertPool.AddCert(parsedCert)
|
||||
c.clusterCertPool.AddCert(parsedCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
|
@ -372,10 +373,10 @@ func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
|||
PrivateKey: c.localClusterPrivateKey,
|
||||
},
|
||||
},
|
||||
RootCAs: c.localClusterCertPool,
|
||||
RootCAs: c.clusterCertPool,
|
||||
ServerName: parsedCert.Subject.CommonName,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: c.localClusterCertPool,
|
||||
ClientCAs: c.clusterCertPool,
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
|
|
|
@ -267,8 +267,8 @@ type Core struct {
|
|||
localClusterPrivateKey crypto.Signer
|
||||
// The local cluster cert
|
||||
localClusterCert []byte
|
||||
// The cert pool containing the self-signed CA as a trusted CA
|
||||
localClusterCertPool *x509.CertPool
|
||||
// The cert pool containing trusted cluster CAs
|
||||
clusterCertPool *x509.CertPool
|
||||
// The TCP addresses we should use for clustering
|
||||
clusterListenerAddrs []*net.TCPAddr
|
||||
// The setup function that gives us the handler to use
|
||||
|
@ -378,16 +378,6 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
|
||||
}
|
||||
|
||||
// Wrap the backend in a cache unless disabled
|
||||
if !conf.DisableCache {
|
||||
_, isCache := conf.Physical.(*physical.Cache)
|
||||
_, isInmem := conf.Physical.(*physical.InmemBackend)
|
||||
if !isCache && !isInmem {
|
||||
cache := physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
|
||||
conf.Physical = cache
|
||||
}
|
||||
}
|
||||
|
||||
if !conf.DisableMlock {
|
||||
// Ensure our memory usage is locked into physical RAM
|
||||
if err := mlock.LockMemory(); err != nil {
|
||||
|
@ -425,11 +415,16 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
cachingDisabled: conf.DisableCache,
|
||||
clusterName: conf.ClusterName,
|
||||
localClusterCertPool: x509.NewCertPool(),
|
||||
clusterCertPool: x509.NewCertPool(),
|
||||
clusterListenerShutdownCh: make(chan struct{}),
|
||||
clusterListenerShutdownSuccessCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Wrap the backend in a cache unless disabled
|
||||
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
|
||||
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
|
||||
}
|
||||
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||
c.ha = conf.HAPhysical
|
||||
}
|
||||
|
@ -714,7 +709,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
|
|||
|
||||
if !oldAdv {
|
||||
// Ensure we are using current values
|
||||
err = c.loadClusterTLS(adv)
|
||||
err = c.loadLocalClusterTLS(adv)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
@ -1134,8 +1129,10 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
}
|
||||
}()
|
||||
c.logger.Info("core: post-unseal setup starting")
|
||||
if cache, ok := c.physical.(*physical.Cache); ok {
|
||||
cache.Purge()
|
||||
|
||||
// Purge the backend if supported
|
||||
if purgable, ok := c.physical.(physical.Purgable); ok {
|
||||
purgable.Purge()
|
||||
}
|
||||
// HA mode requires us to handle keyring rotation and rekeying
|
||||
if c.ha != nil {
|
||||
|
@ -1232,8 +1229,9 @@ func (c *Core) preSeal() error {
|
|||
if err := c.unloadMounts(); err != nil {
|
||||
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
|
||||
}
|
||||
if cache, ok := c.physical.(*physical.Cache); ok {
|
||||
cache.Purge()
|
||||
// Purge the backend if supported
|
||||
if purgable, ok := c.physical.(physical.Purgable); ok {
|
||||
purgable.Purge()
|
||||
}
|
||||
c.logger.Info("core: pre-seal teardown complete")
|
||||
return result
|
||||
|
|
|
@ -119,7 +119,7 @@ func (m *ExpirationManager) Restore() error {
|
|||
defer m.pendingLock.Unlock()
|
||||
|
||||
// Accumulate existing leases
|
||||
existing, err := CollectKeys(m.idView)
|
||||
existing, err := logical.CollectKeys(m.idView)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan for leases: %v", err)
|
||||
}
|
||||
|
@ -292,7 +292,7 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error
|
|||
|
||||
// Accumulate existing leases
|
||||
sub := m.idView.SubView(prefix)
|
||||
existing, err := CollectKeys(sub)
|
||||
existing, err := logical.CollectKeys(sub)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan for leases: %v", err)
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ func (b *CubbyholeBackend) revoke(saltedToken string) error {
|
|||
return fmt.Errorf("cubbyhole: client token empty during revocation")
|
||||
}
|
||||
|
||||
if err := ClearView(b.storageView.(*BarrierView).SubView(saltedToken + "/")); err != nil {
|
||||
if err := logical.ClearView(b.storageView.(*BarrierView).SubView(saltedToken + "/")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -261,7 +261,7 @@ func (c *Core) unmount(path string) (bool, error) {
|
|||
}
|
||||
|
||||
// Clear the data in the view
|
||||
if err := ClearView(view); err != nil {
|
||||
if err := logical.ClearView(view); err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
|
|
|
@ -187,7 +187,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
|
|||
}
|
||||
|
||||
// View should be empty
|
||||
out, err := CollectKeys(view)
|
||||
out, err := logical.CollectKeys(view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -305,7 +305,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
|
|||
}
|
||||
|
||||
// View should not be empty
|
||||
out, err := CollectKeys(view)
|
||||
out, err := logical.CollectKeys(view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -275,7 +275,7 @@ func (ps *PolicyStore) ListPolicies() ([]string, error) {
|
|||
defer metrics.MeasureSince([]string{"policy", "list_policies"}, time.Now())
|
||||
// Scan the view, since the policy names are the same as the
|
||||
// key names.
|
||||
keys, err := CollectKeys(ps.view)
|
||||
keys, err := logical.CollectKeys(ps.view)
|
||||
|
||||
for _, nonAssignable := range nonAssignablePolicies {
|
||||
deleteIndex := -1
|
||||
|
|
|
@ -225,7 +225,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
|
|||
// the TLS state.
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
c.rpcClientConnCancelFunc = cancelFunc
|
||||
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer()), grpc.WithInsecure())
|
||||
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "")), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
c.logger.Error("core/refreshRequestForwardingConnection: err setting up rpc client", "error", err)
|
||||
return err
|
||||
|
@ -330,14 +330,18 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
|
|||
// getGRPCDialer is used to return a dialer that has the correct TLS
|
||||
// configuration. Otherwise gRPC tries to be helpful and stomps all over our
|
||||
// NextProtos.
|
||||
func (c *Core) getGRPCDialer() func(string, time.Duration) (net.Conn, error) {
|
||||
func (c *Core) getGRPCDialer(alpnProto, serverName string) func(string, time.Duration) (net.Conn, error) {
|
||||
return func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
tlsConfig, err := c.ClusterTLSConfig()
|
||||
if err != nil {
|
||||
c.logger.Error("core/getGRPCDialer: failed to get tls configuration", "error", err)
|
||||
c.logger.Error("core: failed to get tls configuration", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.NextProtos = []string{"req_fw_sb-act_v1"}
|
||||
if serverName != "" {
|
||||
tlsConfig.ServerName = serverName
|
||||
}
|
||||
|
||||
tlsConfig.NextProtos = []string{alpnProto}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue