diff --git a/logical/storage_inmem.go b/logical/storage_inmem.go index 03e60118a..e0ff75f14 100644 --- a/logical/storage_inmem.go +++ b/logical/storage_inmem.go @@ -2,10 +2,10 @@ package logical import ( "context" - "strings" "sync" - radix "github.com/armon/go-radix" + "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" ) // InmemStorage implements Storage and stores all data in memory. It is @@ -13,79 +13,55 @@ import ( // having to load all of physical's dependencies (which are legion) just to // have some testing storage. type InmemStorage struct { - sync.RWMutex - root *radix.Tree - once sync.Once + underlying physical.Backend + once sync.Once } func (s *InmemStorage) Get(ctx context.Context, key string) (*StorageEntry, error) { s.once.Do(s.init) - s.RLock() - defer s.RUnlock() - - if raw, ok := s.root.Get(key); ok { - se := raw.(*StorageEntry) - return &StorageEntry{ - Key: se.Key, - Value: se.Value, - }, nil + entry, err := s.underlying.Get(ctx, key) + if err != nil { + return nil, err } - - return nil, nil + if entry == nil { + return nil, nil + } + return &StorageEntry{ + Key: entry.Key, + Value: entry.Value, + SealWrap: entry.SealWrap, + }, nil } func (s *InmemStorage) Put(ctx context.Context, entry *StorageEntry) error { s.once.Do(s.init) - s.Lock() - defer s.Unlock() - - s.root.Insert(entry.Key, &StorageEntry{ - Key: entry.Key, - Value: entry.Value, + return s.underlying.Put(ctx, &physical.Entry{ + Key: entry.Key, + Value: entry.Value, + SealWrap: entry.SealWrap, }) - return nil } func (s *InmemStorage) Delete(ctx context.Context, key string) error { s.once.Do(s.init) - s.Lock() - defer s.Unlock() - - s.root.Delete(key) - return nil + return s.underlying.Delete(ctx, key) } func (s *InmemStorage) List(ctx context.Context, prefix string) ([]string, error) { s.once.Do(s.init) - s.RLock() - defer s.RUnlock() + return s.underlying.List(ctx, prefix) +} - var out []string - seen := make(map[string]interface{}) - walkFn := func(s string, v interface{}) bool { - trimmed := strings.TrimPrefix(s, prefix) - sep := strings.Index(trimmed, "/") - if sep == -1 { - out = append(out, trimmed) - } else { - trimmed = trimmed[:sep+1] - if _, ok := seen[trimmed]; !ok { - out = append(out, trimmed) - seen[trimmed] = struct{}{} - } - } - return false - } - s.root.WalkPrefix(prefix, walkFn) - - return out, nil +func (s *InmemStorage) Underlying() *inmem.InmemBackend { + s.once.Do(s.init) + return s.underlying.(*inmem.InmemBackend) } func (s *InmemStorage) init() { - s.root = radix.New() + s.underlying, _ = inmem.NewInmem(nil, nil) } diff --git a/physical/inmem/inmem.go b/physical/inmem/inmem.go index 4501bab32..4bd577277 100644 --- a/physical/inmem/inmem.go +++ b/physical/inmem/inmem.go @@ -2,8 +2,10 @@ package inmem import ( "context" + "errors" "strings" "sync" + "sync/atomic" "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" @@ -19,6 +21,13 @@ var _ physical.Lock = (*InmemLock)(nil) var _ physical.Transactional = (*TransactionalInmemBackend)(nil) var _ physical.Transactional = (*TransactionalInmemHABackend)(nil) +var ( + PutDisabledError = errors.New("put operations disabled in inmem backend") + GetDisabledError = errors.New("get operations disabled in inmem backend") + DeleteDisabledError = errors.New("delete operations disabled in inmem backend") + ListDisabledError = errors.New("list operations disabled in inmem backend") +) + // InmemBackend is an in-memory only physical backend. It is useful // for testing and development situations where the data is not // expected to be durable. @@ -27,6 +36,10 @@ type InmemBackend struct { root *radix.Tree permitPool *physical.PermitPool logger log.Logger + FailGet *uint32 + FailPut *uint32 + FailDelete *uint32 + FailList *uint32 } type TransactionalInmemBackend struct { @@ -39,6 +52,10 @@ func NewInmem(_ map[string]string, logger log.Logger) (physical.Backend, error) root: radix.New(), permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), logger: logger, + FailGet: new(uint32), + FailPut: new(uint32), + FailDelete: new(uint32), + FailList: new(uint32), } return in, nil } @@ -51,6 +68,10 @@ func NewTransactionalInmem(_ map[string]string, logger log.Logger) (physical.Bac root: radix.New(), permitPool: physical.NewPermitPool(1), logger: logger, + FailGet: new(uint32), + FailPut: new(uint32), + FailDelete: new(uint32), + FailList: new(uint32), }, } return in, nil @@ -68,6 +89,10 @@ func (i *InmemBackend) Put(ctx context.Context, entry *physical.Entry) error { } func (i *InmemBackend) PutInternal(ctx context.Context, entry *physical.Entry) error { + if i.FailPut != nil && atomic.LoadUint32(i.FailPut) != 0 { + return PutDisabledError + } + i.root.Insert(entry.Key, entry.Value) return nil } @@ -84,6 +109,10 @@ func (i *InmemBackend) Get(ctx context.Context, key string) (*physical.Entry, er } func (i *InmemBackend) GetInternal(ctx context.Context, key string) (*physical.Entry, error) { + if i.FailGet != nil && atomic.LoadUint32(i.FailGet) != 0 { + return nil, GetDisabledError + } + if raw, ok := i.root.Get(key); ok { return &physical.Entry{ Key: key, @@ -105,6 +134,10 @@ func (i *InmemBackend) Delete(ctx context.Context, key string) error { } func (i *InmemBackend) DeleteInternal(ctx context.Context, key string) error { + if i.FailDelete != nil && atomic.LoadUint32(i.FailDelete) != 0 { + return DeleteDisabledError + } + i.root.Delete(key) return nil } @@ -122,6 +155,10 @@ func (i *InmemBackend) List(ctx context.Context, prefix string) ([]string, error } func (i *InmemBackend) ListInternal(prefix string) ([]string, error) { + if i.FailList != nil && atomic.LoadUint32(i.FailList) != 0 { + return nil, ListDisabledError + } + var out []string seen := make(map[string]interface{}) walkFn := func(s string, v interface{}) bool {