Add sealunwrapper to ease OSS downgrades (#3936)
This commit is contained in:
parent
847e499261
commit
96ea0620fd
|
@ -380,6 +380,9 @@ type Core struct {
|
|||
// going to be shut down, stepped down, or sealed
|
||||
activeContext context.Context
|
||||
activeContextCancelFunc context.CancelFunc
|
||||
|
||||
// Stores the sealunwrapper for downgrade needs
|
||||
sealUnwrapper physical.Backend
|
||||
}
|
||||
|
||||
// CoreConfig is used to parameterize a core
|
||||
|
@ -517,13 +520,15 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
}
|
||||
c.seal.SetCore(c)
|
||||
|
||||
c.sealUnwrapper = NewSealUnwrapper(phys, conf.Logger)
|
||||
|
||||
var ok bool
|
||||
|
||||
// Wrap the physical backend in a cache layer if enabled
|
||||
if txnOK {
|
||||
c.physical = physical.NewTransactionalCache(phys, conf.CacheSize, conf.Logger)
|
||||
c.physical = physical.NewTransactionalCache(c.sealUnwrapper, conf.CacheSize, conf.Logger)
|
||||
} else {
|
||||
c.physical = physical.NewCache(phys, conf.CacheSize, conf.Logger)
|
||||
c.physical = physical.NewCache(c.sealUnwrapper, conf.CacheSize, conf.Logger)
|
||||
}
|
||||
c.physicalCache = c.physical.(physical.ToggleablePurgemonster)
|
||||
|
||||
|
@ -1580,6 +1585,13 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
c.physicalCache.SetEnabled(true)
|
||||
}
|
||||
|
||||
switch c.sealUnwrapper.(type) {
|
||||
case *sealUnwrapper:
|
||||
c.sealUnwrapper.(*sealUnwrapper).runUnwraps()
|
||||
case *transactionalSealUnwrapper:
|
||||
c.sealUnwrapper.(*transactionalSealUnwrapper).runUnwraps()
|
||||
}
|
||||
|
||||
// Purge these for safety in case of a rekey
|
||||
c.seal.SetBarrierConfig(c.activeContext, nil)
|
||||
if c.seal.RecoveryKeySupported() {
|
||||
|
@ -1685,6 +1697,13 @@ func (c *Core) preSeal() error {
|
|||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
switch c.sealUnwrapper.(type) {
|
||||
case *sealUnwrapper:
|
||||
c.sealUnwrapper.(*sealUnwrapper).stopUnwraps()
|
||||
case *transactionalSealUnwrapper:
|
||||
c.sealUnwrapper.(*transactionalSealUnwrapper).stopUnwraps()
|
||||
}
|
||||
|
||||
// Purge the cache
|
||||
c.physicalCache.SetEnabled(false)
|
||||
c.physicalCache.Purge(c.activeContext)
|
||||
|
|
|
@ -0,0 +1,183 @@
|
|||
// +build !ent
|
||||
// +build !prem
|
||||
// +build !pro
|
||||
// +build !hsm
|
||||
|
||||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
"github.com/hashicorp/vault/helper/locksutil"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
// NewSealUnwrapper creates a new seal unwrapper
|
||||
func NewSealUnwrapper(underlying physical.Backend, logger log.Logger) physical.Backend {
|
||||
ret := &sealUnwrapper{
|
||||
underlying: underlying,
|
||||
logger: logger,
|
||||
locks: locksutil.CreateLocks(),
|
||||
allowUnwraps: new(uint32),
|
||||
}
|
||||
|
||||
if underTxn, ok := underlying.(physical.Transactional); ok {
|
||||
return &transactionalSealUnwrapper{
|
||||
sealUnwrapper: ret,
|
||||
Transactional: underTxn,
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
var _ physical.Backend = (*sealUnwrapper)(nil)
|
||||
var _ physical.Transactional = (*transactionalSealUnwrapper)(nil)
|
||||
|
||||
type sealUnwrapper struct {
|
||||
underlying physical.Backend
|
||||
logger log.Logger
|
||||
locks []*locksutil.LockEntry
|
||||
allowUnwraps *uint32
|
||||
}
|
||||
|
||||
// transactionalSealUnwrapper is a seal unwrapper that wraps a physical that is transactional
|
||||
type transactionalSealUnwrapper struct {
|
||||
*sealUnwrapper
|
||||
physical.Transactional
|
||||
}
|
||||
|
||||
func (d *sealUnwrapper) Put(ctx context.Context, entry *physical.Entry) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
locksutil.LockForKey(d.locks, entry.Key).Lock()
|
||||
defer locksutil.LockForKey(d.locks, entry.Key).Unlock()
|
||||
|
||||
return d.underlying.Put(ctx, entry)
|
||||
}
|
||||
|
||||
func (d *sealUnwrapper) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
||||
entry, err := d.underlying.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var performUnwrap bool
|
||||
se := &physical.SealWrapEntry{}
|
||||
// If the value ends in our canary value, try to decode the bytes.
|
||||
eLen := len(entry.Value)
|
||||
if eLen > 0 && entry.Value[eLen-1] == 's' {
|
||||
if err := proto.Unmarshal(entry.Value[:eLen-1], se); err == nil {
|
||||
// We unmarshaled successfully which means we need to store it as a
|
||||
// non-proto message
|
||||
performUnwrap = true
|
||||
}
|
||||
}
|
||||
if !performUnwrap {
|
||||
return entry, nil
|
||||
}
|
||||
// It's actually encrypted and we can't read it
|
||||
if se.Wrapped {
|
||||
return nil, fmt.Errorf("cannot decode sealwrapped storage entry %s", entry.Key)
|
||||
}
|
||||
if atomic.LoadUint32(d.allowUnwraps) != 1 {
|
||||
return &physical.Entry{
|
||||
Key: entry.Key,
|
||||
Value: se.Ciphertext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
locksutil.LockForKey(d.locks, key).Lock()
|
||||
defer locksutil.LockForKey(d.locks, key).Unlock()
|
||||
|
||||
// At this point we need to re-read and re-check
|
||||
entry, err = d.underlying.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
performUnwrap = false
|
||||
se = &physical.SealWrapEntry{}
|
||||
// If the value ends in our canary value, try to decode the bytes.
|
||||
eLen = len(entry.Value)
|
||||
if eLen > 0 && entry.Value[eLen-1] == 's' {
|
||||
// We ignore an error because the canary is not a guarantee; if it
|
||||
// doesn't decode, proceed normally
|
||||
if err := proto.Unmarshal(entry.Value[:eLen-1], se); err == nil {
|
||||
// We unmarshaled successfully which means we need to store it as a
|
||||
// non-proto message
|
||||
performUnwrap = true
|
||||
}
|
||||
}
|
||||
if !performUnwrap {
|
||||
return entry, nil
|
||||
}
|
||||
if se.Wrapped {
|
||||
return nil, fmt.Errorf("cannot decode sealwrapped storage entry %s", entry.Key)
|
||||
}
|
||||
|
||||
entry = &physical.Entry{
|
||||
Key: entry.Key,
|
||||
Value: se.Ciphertext,
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(d.allowUnwraps) != 1 {
|
||||
return entry, nil
|
||||
}
|
||||
return entry, d.underlying.Put(ctx, entry)
|
||||
}
|
||||
|
||||
func (d *sealUnwrapper) Delete(ctx context.Context, key string) error {
|
||||
locksutil.LockForKey(d.locks, key).Lock()
|
||||
defer locksutil.LockForKey(d.locks, key).Unlock()
|
||||
|
||||
return d.underlying.Delete(ctx, key)
|
||||
}
|
||||
|
||||
func (d *sealUnwrapper) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
return d.underlying.List(ctx, prefix)
|
||||
}
|
||||
|
||||
func (d *transactionalSealUnwrapper) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
|
||||
// Collect keys that need to be locked
|
||||
var keys []string
|
||||
for _, curr := range txns {
|
||||
keys = append(keys, curr.Entry.Key)
|
||||
}
|
||||
// Lock the keys
|
||||
for _, l := range locksutil.LocksForKeys(d.locks, keys) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
}
|
||||
|
||||
if err := d.Transactional.Transaction(ctx, txns); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// This should only run during preSeal which ensures that it can't be run
|
||||
// concurrently and that it will be run only by the active node
|
||||
func (d *sealUnwrapper) stopUnwraps() {
|
||||
atomic.StoreUint32(d.allowUnwraps, 0)
|
||||
}
|
||||
|
||||
func (d *sealUnwrapper) runUnwraps() {
|
||||
// Allow key unwraps on key gets. This gets set only when running on the
|
||||
// active node to prevent standbys from changing data underneath the
|
||||
// primary
|
||||
atomic.StoreUint32(d.allowUnwraps, 1)
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
// +build !ent
|
||||
// +build !prem
|
||||
// +build !pro
|
||||
// +build !hsm
|
||||
|
||||
package vault
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/logbridge"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
"github.com/hashicorp/vault/physical/inmem"
|
||||
)
|
||||
|
||||
func TestSealUnwrapper(t *testing.T) {
|
||||
logger := logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{
|
||||
Mutex: &sync.Mutex{},
|
||||
}))
|
||||
|
||||
// Test without transactions
|
||||
phys, err := inmem.NewInmemHA(nil, logger.LogxiLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
performTestSealUnwrapper(t, phys, logger)
|
||||
|
||||
// Test with transactions
|
||||
tPhys, err := inmem.NewTransactionalInmemHA(nil, logger.LogxiLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
performTestSealUnwrapper(t, tPhys, logger)
|
||||
}
|
||||
|
||||
func performTestSealUnwrapper(t *testing.T, phys physical.Backend, logger *logbridge.Logger) {
|
||||
ctx := context.Background()
|
||||
base := &CoreConfig{
|
||||
Physical: phys,
|
||||
}
|
||||
cluster := NewTestCluster(t, base, &TestClusterOptions{
|
||||
RawLogger: logger,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
// Read a value and then save it back in a proto message
|
||||
entry, err := phys.Get(ctx, "core/master")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entry.Value) == 0 {
|
||||
t.Fatal("got no value for master")
|
||||
}
|
||||
// Save the original for comparison later
|
||||
origBytes := make([]byte, len(entry.Value))
|
||||
copy(origBytes, entry.Value)
|
||||
se := &physical.SealWrapEntry{
|
||||
Ciphertext: entry.Value,
|
||||
}
|
||||
seb, err := proto.Marshal(se)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Write the canary
|
||||
entry.Value = append(seb, 's')
|
||||
// Save the protobuf value for comparison later
|
||||
pBytes := make([]byte, len(entry.Value))
|
||||
copy(pBytes, entry.Value)
|
||||
if err = phys.Put(ctx, entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// At this point we should be able to read through the standby cores,
|
||||
// successfully decode it, but be able to unmarshal it when read back from
|
||||
// the underlying physical store. When we read from active, it should both
|
||||
// successfully decode it and persist it back.
|
||||
checkValue := func(core *Core, wrapped bool) {
|
||||
entry, err := core.physical.Get(ctx, "core/master")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(entry.Value, origBytes) {
|
||||
t.Fatalf("mismatched original bytes and unwrapped entry bytes: %v vs %v", entry.Value, origBytes)
|
||||
}
|
||||
underlyingEntry, err := phys.Get(ctx, "core/master")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
switch wrapped {
|
||||
case true:
|
||||
if !bytes.Equal(underlyingEntry.Value, pBytes) {
|
||||
t.Fatalf("mismatched original bytes and proto entry bytes: %v vs %v", underlyingEntry.Value, pBytes)
|
||||
}
|
||||
default:
|
||||
if !bytes.Equal(underlyingEntry.Value, origBytes) {
|
||||
t.Fatalf("mismatched original bytes and unwrapped entry bytes: %v vs %v", underlyingEntry.Value, origBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TestWaitActive(t, cluster.Cores[0].Core)
|
||||
checkValue(cluster.Cores[2].Core, true)
|
||||
checkValue(cluster.Cores[1].Core, true)
|
||||
checkValue(cluster.Cores[0].Core, false)
|
||||
}
|
Loading…
Reference in New Issue