diff --git a/acl/cache.go b/acl/cache.go new file mode 100644 index 000000000..ad744fb8f --- /dev/null +++ b/acl/cache.go @@ -0,0 +1,89 @@ +package acl + +import ( + "crypto/md5" + "fmt" + + "github.com/hashicorp/golang-lru" +) + +// FaultFunc is a function used to fault in the rules for an +// ACL given it's ID +type FaultFunc func(id string) (string, error) + +// Cache is used to implement policy and ACL caching +type Cache struct { + aclCache *lru.Cache + faultfn FaultFunc + parent ACL + policyCache *lru.Cache +} + +// NewCache contructs a new policy and ACL cache of a given size +func NewCache(size int, parent ACL, faultfn FaultFunc) (*Cache, error) { + if size <= 0 { + return nil, fmt.Errorf("Must provide positive cache size") + } + pc, _ := lru.New(size) + ac, _ := lru.New(size) + c := &Cache{ + aclCache: ac, + faultfn: faultfn, + parent: parent, + policyCache: pc, + } + return c, nil +} + +// GetPolicy is used to get a potentially cached policy set. +// If not cached, it will be parsed, and then cached. +func (c *Cache) GetPolicy(rules string) (*Policy, error) { + hash := fmt.Sprintf("%x", md5.Sum([]byte(rules))) + raw, ok := c.policyCache.Get(hash) + if ok { + return raw.(*Policy), nil + } + policy, err := Parse(rules) + if err != nil { + return nil, err + } + c.policyCache.Add(hash, policy) + return policy, nil +} + +// GetACL is used to get a potentially cached ACL policy. +// If not cached, it will be generated and then cached. +func (c *Cache) GetACL(id string) (ACL, error) { + // Look for the ACL directly + raw, ok := c.aclCache.Get(id) + if ok { + return raw.(ACL), nil + } + + // Get the rules + rules, err := c.faultfn(id) + if err != nil { + return nil, err + } + + // Get the policy + policy, err := c.GetPolicy(rules) + if err != nil { + return nil, err + } + + // Get the ACL + acl, err := New(c.parent, policy) + if err != nil { + return nil, err + } + + // Cache and return the ACL + c.aclCache.Add(id, acl) + return acl, nil +} + +// ClearACL is used to clear the ACL cache if any +func (c *Cache) ClearACL(id string) { + c.aclCache.Remove(id) +} diff --git a/acl/cache_test.go b/acl/cache_test.go new file mode 100644 index 000000000..47602d639 --- /dev/null +++ b/acl/cache_test.go @@ -0,0 +1,130 @@ +package acl + +import ( + "testing" +) + +func TestCache_GetPolicy(t *testing.T) { + c, err := NewCache(1, AllowAll(), nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + p, err := c.GetPolicy("") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Should get the same policy + p1, err := c.GetPolicy("") + if err != nil { + t.Fatalf("err: %v", err) + } + if p != p1 { + t.Fatalf("should be cached") + } + + // Cache a new policy + _, err = c.GetPolicy(testSimplePolicy) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Test invalidation of p + p3, err := c.GetPolicy("") + if err != nil { + t.Fatalf("err: %v", err) + } + if p == p3 { + t.Fatalf("should be not cached") + } +} + +func TestCache_GetACL(t *testing.T) { + policies := map[string]string{ + "foo": testSimplePolicy, + "bar": testSimplePolicy, + } + faultfn := func(id string) (string, error) { + return policies[id], nil + } + + c, err := NewCache(1, DenyAll(), faultfn) + if err != nil { + t.Fatalf("err: %v", err) + } + + acl, err := c.GetACL("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + if acl.KeyRead("bar/test") { + t.Fatalf("should deny") + } + if !acl.KeyRead("foo/test") { + t.Fatalf("should allow") + } + + acl2, err := c.GetACL("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + if acl != acl2 { + t.Fatalf("should be cached") + } + + // Invalidate cache + _, err = c.GetACL("bar") + if err != nil { + t.Fatalf("err: %v", err) + } + + acl3, err := c.GetACL("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + if acl == acl3 { + t.Fatalf("should not be cached") + } +} + +func TestCache_ClearACL(t *testing.T) { + policies := map[string]string{ + "foo": testSimplePolicy, + "bar": testSimplePolicy, + } + faultfn := func(id string) (string, error) { + return policies[id], nil + } + + c, err := NewCache(1, DenyAll(), faultfn) + if err != nil { + t.Fatalf("err: %v", err) + } + + acl, err := c.GetACL("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Nuke the cache + c.ClearACL("foo") + + acl2, err := c.GetACL("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + if acl == acl2 { + t.Fatalf("should not be cached") + } +} + +var testSimplePolicy = ` +key "foo/" { + policy = "read" +} +` diff --git a/acl/policy.go b/acl/policy.go index 2abdc9812..df4af9c06 100644 --- a/acl/policy.go +++ b/acl/policy.go @@ -32,6 +32,11 @@ type KeyPolicy struct { func Parse(rules string) (*Policy, error) { // Decode the rules p := &Policy{} + if rules == "" { + // Hot path for empty rules + return p, nil + } + if err := hcl.Decode(p, rules); err != nil { return nil, fmt.Errorf("Failed to parse ACL rules: %v", err) }