diff --git a/acl/acl.go b/acl/acl.go index 442837340..8a73674d0 100644 --- a/acl/acl.go +++ b/acl/acl.go @@ -46,6 +46,12 @@ type ACL interface { // that deny a write. KeyWritePrefix(string) bool + // ServiceWrite checks for permission to read a given service + ServiceWrite(string) bool + + // ServiceRead checks for permission to read a given service + ServiceRead(string) bool + // ACLList checks for permission to list all the ACLs ACLList() bool @@ -73,6 +79,14 @@ func (s *StaticACL) KeyWritePrefix(string) bool { return s.defaultAllow } +func (s *StaticACL) ServiceRead(string) bool { + return s.defaultAllow +} + +func (s *StaticACL) ServiceWrite(string) bool { + return s.defaultAllow +} + func (s *StaticACL) ACLList() bool { return s.allowManage } @@ -119,20 +133,29 @@ type PolicyACL struct { // keyRules contains the key policies keyRules *radix.Tree + + // serviceRules contains the service policies + serviceRules map[string]string } // New is used to construct a policy based ACL from a set of policies // and a parent policy to resolve missing cases. func New(parent ACL, policy *Policy) (*PolicyACL, error) { p := &PolicyACL{ - parent: parent, - keyRules: radix.New(), + parent: parent, + keyRules: radix.New(), + serviceRules: make(map[string]string, len(policy.Services)), } // Load the key policy for _, kp := range policy.Keys { p.keyRules.Insert(kp.Prefix, kp.Policy) } + + // Load the service policy + for _, sp := range policy.Services { + p.serviceRules[sp.Name] = sp.Policy + } return p, nil } @@ -205,6 +228,48 @@ func (p *PolicyACL) KeyWritePrefix(prefix string) bool { return p.parent.KeyWritePrefix(prefix) } +// ServiceRead checks if reading (discovery) of a service is allowed +func (p *PolicyACL) ServiceRead(name string) bool { + // Check for an exact rule or catch-all + rule, ok := p.serviceRules[name] + if !ok { + rule, ok = p.serviceRules[""] + } + if ok { + switch rule { + case ServicePolicyWrite: + return true + case ServicePolicyRead: + return true + default: + return false + } + } + + // No matching rule, use the parent. + return p.parent.ServiceRead(name) +} + +// ServiceWrite checks if writing (registering) a service is allowed +func (p *PolicyACL) ServiceWrite(name string) bool { + // Check for an exact rule or catch-all + rule, ok := p.serviceRules[name] + if !ok { + rule, ok = p.serviceRules[""] + } + if ok { + switch rule { + case ServicePolicyWrite: + return true + default: + return false + } + } + + // No matching rule, use the parent. + return p.parent.ServiceWrite(name) +} + // ACLList checks if listing of ACLs is allowed func (p *PolicyACL) ACLList() bool { return p.parent.ACLList() diff --git a/acl/acl_test.go b/acl/acl_test.go index 9be0388db..cecc870b9 100644 --- a/acl/acl_test.go +++ b/acl/acl_test.go @@ -41,6 +41,12 @@ func TestStaticACL(t *testing.T) { if !all.KeyWrite("foobar") { t.Fatalf("should allow") } + if !all.ServiceRead("foobar") { + t.Fatalf("should allow") + } + if !all.ServiceWrite("foobar") { + t.Fatalf("should allow") + } if all.ACLList() { t.Fatalf("should not allow") } @@ -54,6 +60,12 @@ func TestStaticACL(t *testing.T) { if none.KeyWrite("foobar") { t.Fatalf("should not allow") } + if none.ServiceRead("foobar") { + t.Fatalf("should not allow") + } + if none.ServiceWrite("foobar") { + t.Fatalf("should not allow") + } if none.ACLList() { t.Fatalf("should not noneow") } @@ -67,6 +79,12 @@ func TestStaticACL(t *testing.T) { if !manage.KeyWrite("foobar") { t.Fatalf("should allow") } + if !manage.ServiceRead("foobar") { + t.Fatalf("should allow") + } + if !manage.ServiceWrite("foobar") { + t.Fatalf("should allow") + } if !manage.ACLList() { t.Fatalf("should allow") } @@ -96,19 +114,33 @@ func TestPolicyACL(t *testing.T) { Policy: KeyPolicyRead, }, }, + Services: []*ServicePolicy{ + &ServicePolicy{ + Name: "", + Policy: ServicePolicyWrite, + }, + &ServicePolicy{ + Name: "foo", + Policy: ServicePolicyRead, + }, + &ServicePolicy{ + Name: "bar", + Policy: ServicePolicyDeny, + }, + }, } acl, err := New(all, policy) if err != nil { t.Fatalf("err: %v", err) } - type tcase struct { + type keycase struct { inp string read bool write bool writePrefix bool } - cases := []tcase{ + cases := []keycase{ {"other", true, true, true}, {"foo/test", true, true, true}, {"foo/priv/test", false, false, false}, @@ -128,6 +160,26 @@ func TestPolicyACL(t *testing.T) { t.Fatalf("Write prefix fail: %#v", c) } } + + // Test the services + type servicecase struct { + inp string + read bool + write bool + } + scases := []servicecase{ + {"other", true, true}, + {"foo", true, false}, + {"bar", false, false}, + } + for _, c := range scases { + if c.read != acl.ServiceRead(c.inp) { + t.Fatalf("Read fail: %#v", c) + } + if c.write != acl.ServiceWrite(c.inp) { + t.Fatalf("Write fail: %#v", c) + } + } } func TestPolicyACL_Parent(t *testing.T) { @@ -143,6 +195,16 @@ func TestPolicyACL_Parent(t *testing.T) { Policy: KeyPolicyRead, }, }, + Services: []*ServicePolicy{ + &ServicePolicy{ + Name: "other", + Policy: ServicePolicyWrite, + }, + &ServicePolicy{ + Name: "foo", + Policy: ServicePolicyRead, + }, + }, } root, err := New(deny, policyRoot) if err != nil { @@ -164,19 +226,25 @@ func TestPolicyACL_Parent(t *testing.T) { Policy: KeyPolicyRead, }, }, + Services: []*ServicePolicy{ + &ServicePolicy{ + Name: "bar", + Policy: ServicePolicyDeny, + }, + }, } acl, err := New(root, policy) if err != nil { t.Fatalf("err: %v", err) } - type tcase struct { + type keycase struct { inp string read bool write bool writePrefix bool } - cases := []tcase{ + cases := []keycase{ {"other", false, false, false}, {"foo/test", true, true, true}, {"foo/priv/test", true, false, false}, @@ -194,4 +262,25 @@ func TestPolicyACL_Parent(t *testing.T) { t.Fatalf("Write prefix fail: %#v", c) } } + + // Test the services + type servicecase struct { + inp string + read bool + write bool + } + scases := []servicecase{ + {"fail", false, false}, + {"other", true, true}, + {"foo", true, false}, + {"bar", false, false}, + } + for _, c := range scases { + if c.read != acl.ServiceRead(c.inp) { + t.Fatalf("Read fail: %#v", c) + } + if c.write != acl.ServiceWrite(c.inp) { + t.Fatalf("Write fail: %#v", c) + } + } }