Merge pull request #1299 from hashicorp/utility-enhancements

Utility Enhancements
This commit is contained in:
Vishal Nayak 2016-04-05 20:47:40 -04:00
commit dbc8162ae4
18 changed files with 260 additions and 101 deletions

View file

@ -8,7 +8,7 @@ import (
"net"
"strings"
"github.com/hashicorp/vault/helper/policies"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -123,7 +123,7 @@ func (b *backend) pathLoginRenew(
if err != nil {
return nil, err
}
if !policies.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
return logical.ErrorResponse("policies do not match"), nil
}

View file

@ -6,6 +6,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -137,19 +138,13 @@ func (b *backend) pathCertWrite(
name := strings.ToLower(d.Get("name").(string))
certificate := d.Get("certificate").(string)
displayName := d.Get("display_name").(string)
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
policies := policyutil.ParsePolicies(d.Get("policies").(string))
// Default the display name to the certificate name if not given
if displayName == "" {
displayName = name
}
if len(policies) == 0 {
return logical.ErrorResponse("policies required"), nil
}
parsed := parsePEM([]byte(certificate))
if len(parsed) == 0 {
return logical.ErrorResponse("failed to parse certificate"), nil

View file

@ -10,7 +10,7 @@ import (
"strings"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/policies"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew(
return nil, nil
}
if !policies.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil
}

View file

@ -5,7 +5,7 @@ import (
"net/url"
"github.com/google/go-github/github"
"github.com/hashicorp/vault/helper/policies"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew(
} else {
verifyResp = verifyResponse
}
if !policies.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies do not match"), nil
}

View file

@ -3,6 +3,7 @@ package ldap
import (
"strings"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -25,7 +26,7 @@ func pathGroups(b *backend) *framework.Path {
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.DeleteOperation: b.pathGroupDelete,
logical.ReadOperation: b.pathGroupRead,
logical.UpdateOperation: b.pathGroupWrite,
logical.UpdateOperation: b.pathGroupWrite,
},
HelpSynopsis: pathGroupHelpSyn,
@ -79,15 +80,9 @@ func (b *backend) pathGroupRead(
func (b *backend) pathGroupWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
// Store it
entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{
Policies: policies,
entry, err := logical.StorageEntryJSON("group/"+d.Get("name").(string), &GroupEntry{
Policies: policyutil.ParsePolicies(d.Get("policies").(string)),
})
if err != nil {
return nil, err

View file

@ -4,7 +4,7 @@ import (
"sort"
"strings"
"github.com/hashicorp/vault/helper/policies"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew(
return resp, err
}
if !policies.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil
}

View file

@ -5,7 +5,7 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/policies"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/bcrypt"
@ -93,7 +93,7 @@ func (b *backend) pathLoginRenew(
return nil, nil
}
if !policies.EquivalentPolicies(user.Policies, req.Auth.Policies) {
if !policyutil.EquivalentPolicies(user.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil
}

View file

@ -2,8 +2,8 @@ package userpass
import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -52,23 +52,11 @@ func (b *backend) pathUserPoliciesUpdate(
return nil, fmt.Errorf("username does not exist")
}
err = b.updateUserPolicies(req, d, userEntry)
if err != nil {
return nil, err
}
userEntry.Policies = policyutil.ParsePolicies(d.Get("policies").(string))
return nil, b.setUser(req.Storage, username, userEntry)
}
func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
userEntry.Policies = policies
return nil
}
const pathUserPoliciesHelpSyn = `
Update the policies associated with the username.
`

View file

@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -137,11 +138,8 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData)
}
}
if _, ok := d.GetOk("policies"); ok {
err = b.updateUserPolicies(req, d, userEntry)
if err != nil {
return nil, err
}
if policiesRaw, ok := d.GetOk("policies"); ok {
userEntry.Policies = policyutil.ParsePolicies(policiesRaw.(string))
}
ttlStr := userEntry.TTL.String()

View file

@ -0,0 +1,87 @@
package policyutil
import (
"sort"
"strings"
)
func ParsePolicies(policiesRaw string) []string {
policies := strings.Split(policiesRaw, ",")
defaultFound := false
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
// If 'root' policy is present, ignore all other policies.
if policies[i] == "root" {
policies = []string{"root"}
defaultFound = true
break
}
if policies[i] == "default" {
defaultFound = true
}
}
// Always add 'default' except only if the policies contain 'root'.
if len(policies) == 0 || !defaultFound {
policies = append(policies, "default")
}
// Sort to make the computations on policies consistent.
sort.Strings(policies)
return policies
}
// ComparePolicies checks whether the given policy sets are equivalent, as in,
// they contain the same values. The benefit of this method is that it leaves
// the "default" policy out of its comparisons as it may be added later by core
// after a set of policies has been saved by a backend.
func EquivalentPolicies(a, b []string) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
// First we'll build maps to ensure unique values and filter default
mapA := map[string]bool{}
mapB := map[string]bool{}
for _, keyA := range a {
if keyA == "default" {
continue
}
mapA[keyA] = true
}
for _, keyB := range b {
if keyB == "default" {
continue
}
mapB[keyB] = true
}
// Now we'll build our checking slices
var sortedA, sortedB []string
for keyA, _ := range mapA {
sortedA = append(sortedA, keyA)
}
for keyB, _ := range mapB {
sortedB = append(sortedB, keyB)
}
sort.Strings(sortedA)
sort.Strings(sortedB)
// Finally, compare
if len(sortedA) != len(sortedB) {
return false
}
for i := range sortedA {
if sortedA[i] != sortedB[i] {
return false
}
}
return true
}

View file

@ -0,0 +1,61 @@
package policyutil
import "testing"
func TestParsePolicies(t *testing.T) {
expected := []string{"foo", "bar", "default"}
actual := ParsePolicies("foo,bar")
// add default if not present.
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// do not add default more than once.
actual = ParsePolicies("foo,bar,default")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// handle spaces and tabs.
actual = ParsePolicies(" foo , bar , default")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// ignore all others if root is present.
expected = []string{"root"}
actual = ParsePolicies("foo,bar,root")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// with spaces and tabs.
expected = []string{"root"}
actual = ParsePolicies("foo ,bar, root ")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
}
func TestEquivalentPolicies(t *testing.T) {
a := []string{"foo", "bar"}
var b []string
if EquivalentPolicies(a, b) {
t.Fatal("bad")
}
b = []string{"foo"}
if EquivalentPolicies(a, b) {
t.Fatal("bad")
}
b = []string{"bar", "foo"}
if !EquivalentPolicies(a, b) {
t.Fatal("bad")
}
b = []string{"foo", "default", "bar"}
if !EquivalentPolicies(a, b) {
t.Fatal("bad")
}
}

22
helper/strutil/strutil.go Normal file
View file

@ -0,0 +1,22 @@
package strutil
// StrListContains looks for a string in a list of strings.
func StrListContains(haystack []string, needle string) bool {
for _, item := range haystack {
if item == needle {
return true
}
}
return false
}
// StrListSubset checks if a given list is a subset
// of another set
func StrListSubset(super, sub []string) bool {
for _, item := range sub {
if !StrListContains(super, item) {
return false
}
}
return true
}

View file

@ -0,0 +1,49 @@
package strutil
import "testing"
func TestStrListContains(t *testing.T) {
haystack := []string{
"dev",
"ops",
"prod",
"root",
}
if StrListContains(haystack, "tubez") {
t.Fatalf("Bad")
}
if !StrListContains(haystack, "root") {
t.Fatalf("Bad")
}
}
func TestStrListSubset(t *testing.T) {
parent := []string{
"dev",
"ops",
"prod",
"root",
}
child := []string{
"prod",
"ops",
}
if !StrListSubset(parent, child) {
t.Fatalf("Bad")
}
if !StrListSubset(parent, parent) {
t.Fatalf("Bad")
}
if !StrListSubset(child, child) {
t.Fatalf("Bad")
}
if !StrListSubset(child, nil) {
t.Fatalf("Bad")
}
if StrListSubset(child, parent) {
t.Fatalf("Bad")
}
if StrListSubset(nil, child) {
t.Fatalf("Bad")
}
}

View file

@ -18,6 +18,7 @@ import (
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/mlock"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical"
"github.com/hashicorp/vault/shamir"
@ -595,7 +596,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
}
// Set the default lease if non-provided, root tokens are exempt
if auth.TTL == 0 && !strListContains(auth.Policies, "root") {
if auth.TTL == 0 && !strutil.StrListContains(auth.Policies, "root") {
auth.TTL = sysView.DefaultLeaseTTL()
}
@ -614,7 +615,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
TTL: auth.TTL,
}
if strListSubset(te.Policies, []string{"root"}) {
if strutil.StrListSubset(te.Policies, []string{"root"}) {
te.Policies = []string{"root"}
} else {
// Use a map to filter out/prevent duplicates

View file

@ -12,6 +12,7 @@ import (
"github.com/fatih/structs"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/mitchellh/mapstructure"
@ -889,7 +890,7 @@ func (ts *TokenStore) handleCreateCommon(
if len(data.Policies) == 0 {
data.Policies = role.AllowedPolicies
} else {
if !strListSubset(role.AllowedPolicies, data.Policies) {
if !strutil.StrListSubset(role.AllowedPolicies, data.Policies) {
return logical.ErrorResponse("token policies must be subset of the role's allowed policies"), logical.ErrInvalidRequest
}
}
@ -899,7 +900,7 @@ func (ts *TokenStore) handleCreateCommon(
// When a role is not in use, only permit policies to be a subset unless
// the client has root or sudo privileges
case !isSudo && !strListSubset(parent.Policies, data.Policies):
case !isSudo && !strutil.StrListSubset(parent.Policies, data.Policies):
return logical.ErrorResponse("child policies must be subset of parent"), logical.ErrInvalidRequest
}
@ -972,7 +973,7 @@ func (ts *TokenStore) handleCreateCommon(
sysView := ts.System()
// Set the default lease if non-provided, root tokens are exempt
if te.TTL == 0 && !strListContains(te.Policies, "root") {
if te.TTL == 0 && !strutil.StrListContains(te.Policies, "root") {
te.TTL = sysView.DefaultLeaseTTL()
}

View file

@ -16,49 +16,3 @@ func TestRandBytes(t *testing.T) {
t.Fatalf("bad: %v", b)
}
}
func TestStrListContains(t *testing.T) {
haystack := []string{
"dev",
"ops",
"prod",
"root",
}
if strListContains(haystack, "tubez") {
t.Fatalf("Bad")
}
if !strListContains(haystack, "root") {
t.Fatalf("Bad")
}
}
func TestStrListSubset(t *testing.T) {
parent := []string{
"dev",
"ops",
"prod",
"root",
}
child := []string{
"prod",
"ops",
}
if !strListSubset(parent, child) {
t.Fatalf("Bad")
}
if !strListSubset(parent, parent) {
t.Fatalf("Bad")
}
if !strListSubset(child, child) {
t.Fatalf("Bad")
}
if !strListSubset(child, nil) {
t.Fatalf("Bad")
}
if strListSubset(child, parent) {
t.Fatalf("Bad")
}
if strListSubset(nil, child) {
t.Fatalf("Bad")
}
}

View file

@ -6,13 +6,21 @@ import (
"fmt"
)
// GenerateRandomBytes is used to generate random bytes of given size.
func GenerateRandomBytes(size int) ([]byte, error) {
buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
return nil, fmt.Errorf("failed to read random bytes: %v", err)
}
return buf, nil
}
// GenerateUUID is used to generate a random UUID
func GenerateUUID() (string, error) {
buf := make([]byte, 16)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("failed to read random bytes: %v", err)
buf, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return FormatUUID(buf)
}

View file

@ -247,7 +247,7 @@ of the header should be "X-Vault-Token" and the value should be the token.
</li>
<li>
<span class="param">policies</span>
<span class="param-flags">required</span>
<span class="param-flags">optional</span>
A comma-separated list of policies to set on tokens issued when
authenticating against this CA certificate.
</li>