Add sealunwrapper to ease OSS downgrades (#3936)

This commit is contained in:
Jeff Mitchell 2018-02-09 16:37:40 -05:00 committed by GitHub
parent 847e499261
commit 96ea0620fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 315 additions and 2 deletions

View File

@ -380,6 +380,9 @@ type Core struct {
// going to be shut down, stepped down, or sealed
activeContext context.Context
activeContextCancelFunc context.CancelFunc
// Stores the sealunwrapper for downgrade needs
sealUnwrapper physical.Backend
}
// CoreConfig is used to parameterize a core
@ -517,13 +520,15 @@ func NewCore(conf *CoreConfig) (*Core, error) {
}
c.seal.SetCore(c)
c.sealUnwrapper = NewSealUnwrapper(phys, conf.Logger)
var ok bool
// Wrap the physical backend in a cache layer if enabled
if txnOK {
c.physical = physical.NewTransactionalCache(phys, conf.CacheSize, conf.Logger)
c.physical = physical.NewTransactionalCache(c.sealUnwrapper, conf.CacheSize, conf.Logger)
} else {
c.physical = physical.NewCache(phys, conf.CacheSize, conf.Logger)
c.physical = physical.NewCache(c.sealUnwrapper, conf.CacheSize, conf.Logger)
}
c.physicalCache = c.physical.(physical.ToggleablePurgemonster)
@ -1580,6 +1585,13 @@ func (c *Core) postUnseal() (retErr error) {
c.physicalCache.SetEnabled(true)
}
switch c.sealUnwrapper.(type) {
case *sealUnwrapper:
c.sealUnwrapper.(*sealUnwrapper).runUnwraps()
case *transactionalSealUnwrapper:
c.sealUnwrapper.(*transactionalSealUnwrapper).runUnwraps()
}
// Purge these for safety in case of a rekey
c.seal.SetBarrierConfig(c.activeContext, nil)
if c.seal.RecoveryKeySupported() {
@ -1685,6 +1697,13 @@ func (c *Core) preSeal() error {
result = multierror.Append(result, err)
}
switch c.sealUnwrapper.(type) {
case *sealUnwrapper:
c.sealUnwrapper.(*sealUnwrapper).stopUnwraps()
case *transactionalSealUnwrapper:
c.sealUnwrapper.(*transactionalSealUnwrapper).stopUnwraps()
}
// Purge the cache
c.physicalCache.SetEnabled(false)
c.physicalCache.Purge(c.activeContext)

183
vault/sealunwrapper.go Normal file
View File

@ -0,0 +1,183 @@
// +build !ent
// +build !prem
// +build !pro
// +build !hsm
package vault
import (
"context"
"fmt"
"sync/atomic"
proto "github.com/golang/protobuf/proto"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/physical"
log "github.com/mgutz/logxi/v1"
)
// NewSealUnwrapper creates a new seal unwrapper
func NewSealUnwrapper(underlying physical.Backend, logger log.Logger) physical.Backend {
ret := &sealUnwrapper{
underlying: underlying,
logger: logger,
locks: locksutil.CreateLocks(),
allowUnwraps: new(uint32),
}
if underTxn, ok := underlying.(physical.Transactional); ok {
return &transactionalSealUnwrapper{
sealUnwrapper: ret,
Transactional: underTxn,
}
}
return ret
}
var _ physical.Backend = (*sealUnwrapper)(nil)
var _ physical.Transactional = (*transactionalSealUnwrapper)(nil)
type sealUnwrapper struct {
underlying physical.Backend
logger log.Logger
locks []*locksutil.LockEntry
allowUnwraps *uint32
}
// transactionalSealUnwrapper is a seal unwrapper that wraps a physical that is transactional
type transactionalSealUnwrapper struct {
*sealUnwrapper
physical.Transactional
}
func (d *sealUnwrapper) Put(ctx context.Context, entry *physical.Entry) error {
if entry == nil {
return nil
}
locksutil.LockForKey(d.locks, entry.Key).Lock()
defer locksutil.LockForKey(d.locks, entry.Key).Unlock()
return d.underlying.Put(ctx, entry)
}
func (d *sealUnwrapper) Get(ctx context.Context, key string) (*physical.Entry, error) {
entry, err := d.underlying.Get(ctx, key)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var performUnwrap bool
se := &physical.SealWrapEntry{}
// If the value ends in our canary value, try to decode the bytes.
eLen := len(entry.Value)
if eLen > 0 && entry.Value[eLen-1] == 's' {
if err := proto.Unmarshal(entry.Value[:eLen-1], se); err == nil {
// We unmarshaled successfully which means we need to store it as a
// non-proto message
performUnwrap = true
}
}
if !performUnwrap {
return entry, nil
}
// It's actually encrypted and we can't read it
if se.Wrapped {
return nil, fmt.Errorf("cannot decode sealwrapped storage entry %s", entry.Key)
}
if atomic.LoadUint32(d.allowUnwraps) != 1 {
return &physical.Entry{
Key: entry.Key,
Value: se.Ciphertext,
}, nil
}
locksutil.LockForKey(d.locks, key).Lock()
defer locksutil.LockForKey(d.locks, key).Unlock()
// At this point we need to re-read and re-check
entry, err = d.underlying.Get(ctx, key)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
performUnwrap = false
se = &physical.SealWrapEntry{}
// If the value ends in our canary value, try to decode the bytes.
eLen = len(entry.Value)
if eLen > 0 && entry.Value[eLen-1] == 's' {
// We ignore an error because the canary is not a guarantee; if it
// doesn't decode, proceed normally
if err := proto.Unmarshal(entry.Value[:eLen-1], se); err == nil {
// We unmarshaled successfully which means we need to store it as a
// non-proto message
performUnwrap = true
}
}
if !performUnwrap {
return entry, nil
}
if se.Wrapped {
return nil, fmt.Errorf("cannot decode sealwrapped storage entry %s", entry.Key)
}
entry = &physical.Entry{
Key: entry.Key,
Value: se.Ciphertext,
}
if atomic.LoadUint32(d.allowUnwraps) != 1 {
return entry, nil
}
return entry, d.underlying.Put(ctx, entry)
}
func (d *sealUnwrapper) Delete(ctx context.Context, key string) error {
locksutil.LockForKey(d.locks, key).Lock()
defer locksutil.LockForKey(d.locks, key).Unlock()
return d.underlying.Delete(ctx, key)
}
func (d *sealUnwrapper) List(ctx context.Context, prefix string) ([]string, error) {
return d.underlying.List(ctx, prefix)
}
func (d *transactionalSealUnwrapper) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
// Collect keys that need to be locked
var keys []string
for _, curr := range txns {
keys = append(keys, curr.Entry.Key)
}
// Lock the keys
for _, l := range locksutil.LocksForKeys(d.locks, keys) {
l.Lock()
defer l.Unlock()
}
if err := d.Transactional.Transaction(ctx, txns); err != nil {
return err
}
return nil
}
// This should only run during preSeal which ensures that it can't be run
// concurrently and that it will be run only by the active node
func (d *sealUnwrapper) stopUnwraps() {
atomic.StoreUint32(d.allowUnwraps, 0)
}
func (d *sealUnwrapper) runUnwraps() {
// Allow key unwraps on key gets. This gets set only when running on the
// active node to prevent standbys from changing data underneath the
// primary
atomic.StoreUint32(d.allowUnwraps, 1)
}

111
vault/sealunwrapper_test.go Normal file
View File

@ -0,0 +1,111 @@
// +build !ent
// +build !prem
// +build !pro
// +build !hsm
package vault
import (
"bytes"
"context"
"sync"
"testing"
proto "github.com/golang/protobuf/proto"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/logbridge"
"github.com/hashicorp/vault/physical"
"github.com/hashicorp/vault/physical/inmem"
)
func TestSealUnwrapper(t *testing.T) {
logger := logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{
Mutex: &sync.Mutex{},
}))
// Test without transactions
phys, err := inmem.NewInmemHA(nil, logger.LogxiLogger())
if err != nil {
t.Fatal(err)
}
performTestSealUnwrapper(t, phys, logger)
// Test with transactions
tPhys, err := inmem.NewTransactionalInmemHA(nil, logger.LogxiLogger())
if err != nil {
t.Fatal(err)
}
performTestSealUnwrapper(t, tPhys, logger)
}
func performTestSealUnwrapper(t *testing.T, phys physical.Backend, logger *logbridge.Logger) {
ctx := context.Background()
base := &CoreConfig{
Physical: phys,
}
cluster := NewTestCluster(t, base, &TestClusterOptions{
RawLogger: logger,
})
cluster.Start()
defer cluster.Cleanup()
// Read a value and then save it back in a proto message
entry, err := phys.Get(ctx, "core/master")
if err != nil {
t.Fatal(err)
}
if len(entry.Value) == 0 {
t.Fatal("got no value for master")
}
// Save the original for comparison later
origBytes := make([]byte, len(entry.Value))
copy(origBytes, entry.Value)
se := &physical.SealWrapEntry{
Ciphertext: entry.Value,
}
seb, err := proto.Marshal(se)
if err != nil {
t.Fatal(err)
}
// Write the canary
entry.Value = append(seb, 's')
// Save the protobuf value for comparison later
pBytes := make([]byte, len(entry.Value))
copy(pBytes, entry.Value)
if err = phys.Put(ctx, entry); err != nil {
t.Fatal(err)
}
// At this point we should be able to read through the standby cores,
// successfully decode it, but be able to unmarshal it when read back from
// the underlying physical store. When we read from active, it should both
// successfully decode it and persist it back.
checkValue := func(core *Core, wrapped bool) {
entry, err := core.physical.Get(ctx, "core/master")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(entry.Value, origBytes) {
t.Fatalf("mismatched original bytes and unwrapped entry bytes: %v vs %v", entry.Value, origBytes)
}
underlyingEntry, err := phys.Get(ctx, "core/master")
if err != nil {
t.Fatal(err)
}
switch wrapped {
case true:
if !bytes.Equal(underlyingEntry.Value, pBytes) {
t.Fatalf("mismatched original bytes and proto entry bytes: %v vs %v", underlyingEntry.Value, pBytes)
}
default:
if !bytes.Equal(underlyingEntry.Value, origBytes) {
t.Fatalf("mismatched original bytes and unwrapped entry bytes: %v vs %v", underlyingEntry.Value, origBytes)
}
}
}
TestWaitActive(t, cluster.Cores[0].Core)
checkValue(cluster.Cores[2].Core, true)
checkValue(cluster.Cores[1].Core, true)
checkValue(cluster.Cores[0].Core, false)
}