diff --git a/acl/acl.go b/acl/acl.go index 8a73674d0..cfe1ff6b6 100644 --- a/acl/acl.go +++ b/acl/acl.go @@ -135,7 +135,7 @@ type PolicyACL struct { keyRules *radix.Tree // serviceRules contains the service policies - serviceRules map[string]string + serviceRules *radix.Tree } // New is used to construct a policy based ACL from a set of policies @@ -144,7 +144,7 @@ func New(parent ACL, policy *Policy) (*PolicyACL, error) { p := &PolicyACL{ parent: parent, keyRules: radix.New(), - serviceRules: make(map[string]string, len(policy.Services)), + serviceRules: radix.New(), } // Load the key policy @@ -154,7 +154,7 @@ func New(parent ACL, policy *Policy) (*PolicyACL, error) { // Load the service policy for _, sp := range policy.Services { - p.serviceRules[sp.Name] = sp.Policy + p.serviceRules.Insert(sp.Name, sp.Policy) } return p, nil } @@ -231,10 +231,8 @@ func (p *PolicyACL) KeyWritePrefix(prefix string) bool { // ServiceRead checks if reading (discovery) of a service is allowed func (p *PolicyACL) ServiceRead(name string) bool { // Check for an exact rule or catch-all - rule, ok := p.serviceRules[name] - if !ok { - rule, ok = p.serviceRules[""] - } + _, rule, ok := p.serviceRules.LongestPrefix(name) + if ok { switch rule { case ServicePolicyWrite: @@ -253,10 +251,8 @@ func (p *PolicyACL) ServiceRead(name string) bool { // ServiceWrite checks if writing (registering) a service is allowed func (p *PolicyACL) ServiceWrite(name string) bool { // Check for an exact rule or catch-all - rule, ok := p.serviceRules[name] - if !ok { - rule, ok = p.serviceRules[""] - } + _, rule, ok := p.serviceRules.LongestPrefix(name) + if ok { switch rule { case ServicePolicyWrite: diff --git a/acl/acl_test.go b/acl/acl_test.go index cecc870b9..d6da2f93e 100644 --- a/acl/acl_test.go +++ b/acl/acl_test.go @@ -127,6 +127,10 @@ func TestPolicyACL(t *testing.T) { Name: "bar", Policy: ServicePolicyDeny, }, + &ServicePolicy{ + Name: "barfoo", + Policy: ServicePolicyWrite, + }, }, } acl, err := New(all, policy) @@ -171,6 +175,10 @@ func TestPolicyACL(t *testing.T) { {"other", true, true}, {"foo", true, false}, {"bar", false, false}, + {"foobar", true, false}, + {"barfo", false, false}, + {"barfoo", true, true}, + {"barfoo2", true, true}, } for _, c := range scases { if c.read != acl.ServiceRead(c.inp) {