diff --git a/vault/barrier_view.go b/vault/barrier_view.go index 0f3940e5f..41214b56f 100644 --- a/vault/barrier_view.go +++ b/vault/barrier_view.go @@ -1,6 +1,7 @@ package vault import ( + "fmt" "strings" "github.com/hashicorp/vault/logical" @@ -80,3 +81,30 @@ func (v *BarrierView) expandKey(suffix string) string { func (v *BarrierView) truncateKey(full string) string { return strings.TrimPrefix(full, v.prefix) } + +// ScanView is used to scan all the keys in a view recursively +func ScanView(view *BarrierView, cb func(path string)) error { + frontier := []string{""} + for len(frontier) > 0 { + n := len(frontier) + current := frontier[n-1] + frontier = frontier[:n-1] + + // List the contents + contents, err := view.List(current) + if err != nil { + return fmt.Errorf("list failed at path '%s': %v", current, err) + } + + // Handle the contents in the directory + for _, c := range contents { + fullPath := current + c + if strings.HasSuffix(c, "/") { + frontier = append(frontier, fullPath) + } else { + cb(fullPath) + } + } + } + return nil +} diff --git a/vault/barrier_view_test.go b/vault/barrier_view_test.go index 5b5ca6911..71fe94e23 100644 --- a/vault/barrier_view_test.go +++ b/vault/barrier_view_test.go @@ -1,6 +1,8 @@ package vault import ( + "reflect" + "sort" "testing" "github.com/hashicorp/vault/logical" @@ -143,3 +145,41 @@ func TestBarrierView_SubView(t *testing.T) { t.Fatalf("nested foo/bar/test should be gone") } } + +func TestBarrierView_Scan(t *testing.T) { + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "view/") + + expect := []string{} + ent := []*logical.StorageEntry{ + &logical.StorageEntry{Key: "foo", Value: []byte("test")}, + &logical.StorageEntry{Key: "zip", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/bar", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/zap", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/bar/baz", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/bar/zoo", Value: []byte("test")}, + } + + for _, e := range ent { + expect = append(expect, e.Key) + if err := view.Put(e); err != nil { + t.Fatalf("err: %v", err) + } + } + + var out []string + cb := func(path string) { + out = append(out, path) + } + + // Collect the keys + if err := ScanView(view, cb); err != nil { + t.Fatalf("err: %v", err) + } + + sort.Strings(out) + sort.Strings(expect) + if !reflect.DeepEqual(out, expect) { + t.Fatalf("out: %v expect: %v", out, expect) + } +}