Merge pull request #1299 from hashicorp/utility-enhancements

Utility Enhancements
This commit is contained in:
Vishal Nayak 2016-04-05 20:47:40 -04:00
commit dbc8162ae4
18 changed files with 260 additions and 101 deletions

View File

@ -8,7 +8,7 @@ import (
"net" "net"
"strings" "strings"
"github.com/hashicorp/vault/helper/policies" "github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -123,7 +123,7 @@ func (b *backend) pathLoginRenew(
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !policies.EquivalentPolicies(mapPolicies, req.Auth.Policies) { if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
return logical.ErrorResponse("policies do not match"), nil return logical.ErrorResponse("policies do not match"), nil
} }

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -137,19 +138,13 @@ func (b *backend) pathCertWrite(
name := strings.ToLower(d.Get("name").(string)) name := strings.ToLower(d.Get("name").(string))
certificate := d.Get("certificate").(string) certificate := d.Get("certificate").(string)
displayName := d.Get("display_name").(string) displayName := d.Get("display_name").(string)
policies := strings.Split(d.Get("policies").(string), ",") policies := policyutil.ParsePolicies(d.Get("policies").(string))
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
// Default the display name to the certificate name if not given // Default the display name to the certificate name if not given
if displayName == "" { if displayName == "" {
displayName = name displayName = name
} }
if len(policies) == 0 {
return logical.ErrorResponse("policies required"), nil
}
parsed := parsePEM([]byte(certificate)) parsed := parsePEM([]byte(certificate))
if len(parsed) == 0 { if len(parsed) == 0 {
return logical.ErrorResponse("failed to parse certificate"), nil return logical.ErrorResponse("failed to parse certificate"), nil

View File

@ -10,7 +10,7 @@ import (
"strings" "strings"
"github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/policies" "github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew(
return nil, nil return nil, nil
} }
if !policies.EquivalentPolicies(cert.Policies, req.Auth.Policies) { if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil return logical.ErrorResponse("policies have changed, not renewing"), nil
} }

View File

@ -5,7 +5,7 @@ import (
"net/url" "net/url"
"github.com/google/go-github/github" "github.com/google/go-github/github"
"github.com/hashicorp/vault/helper/policies" "github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew(
} else { } else {
verifyResp = verifyResponse verifyResp = verifyResponse
} }
if !policies.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) { if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies do not match"), nil return logical.ErrorResponse("policies do not match"), nil
} }

View File

@ -3,6 +3,7 @@ package ldap
import ( import (
"strings" "strings"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -25,7 +26,7 @@ func pathGroups(b *backend) *framework.Path {
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
logical.DeleteOperation: b.pathGroupDelete, logical.DeleteOperation: b.pathGroupDelete,
logical.ReadOperation: b.pathGroupRead, logical.ReadOperation: b.pathGroupRead,
logical.UpdateOperation: b.pathGroupWrite, logical.UpdateOperation: b.pathGroupWrite,
}, },
HelpSynopsis: pathGroupHelpSyn, HelpSynopsis: pathGroupHelpSyn,
@ -79,15 +80,9 @@ func (b *backend) pathGroupRead(
func (b *backend) pathGroupWrite( func (b *backend) pathGroupWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
// Store it // Store it
entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{ entry, err := logical.StorageEntryJSON("group/"+d.Get("name").(string), &GroupEntry{
Policies: policies, Policies: policyutil.ParsePolicies(d.Get("policies").(string)),
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -4,7 +4,7 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/hashicorp/vault/helper/policies" "github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew(
return resp, err return resp, err
} }
if !policies.EquivalentPolicies(loginPolicies, req.Auth.Policies) { if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil return logical.ErrorResponse("policies have changed, not renewing"), nil
} }

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/hashicorp/vault/helper/policies" "github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -93,7 +93,7 @@ func (b *backend) pathLoginRenew(
return nil, nil return nil, nil
} }
if !policies.EquivalentPolicies(user.Policies, req.Auth.Policies) { if !policyutil.EquivalentPolicies(user.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil return logical.ErrorResponse("policies have changed, not renewing"), nil
} }

View File

@ -2,8 +2,8 @@ package userpass
import ( import (
"fmt" "fmt"
"strings"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -52,23 +52,11 @@ func (b *backend) pathUserPoliciesUpdate(
return nil, fmt.Errorf("username does not exist") return nil, fmt.Errorf("username does not exist")
} }
err = b.updateUserPolicies(req, d, userEntry) userEntry.Policies = policyutil.ParsePolicies(d.Get("policies").(string))
if err != nil {
return nil, err
}
return nil, b.setUser(req.Storage, username, userEntry) return nil, b.setUser(req.Storage, username, userEntry)
} }
func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
userEntry.Policies = policies
return nil
}
const pathUserPoliciesHelpSyn = ` const pathUserPoliciesHelpSyn = `
Update the policies associated with the username. Update the policies associated with the username.
` `

View File

@ -5,6 +5,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -137,11 +138,8 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData)
} }
} }
if _, ok := d.GetOk("policies"); ok { if policiesRaw, ok := d.GetOk("policies"); ok {
err = b.updateUserPolicies(req, d, userEntry) userEntry.Policies = policyutil.ParsePolicies(policiesRaw.(string))
if err != nil {
return nil, err
}
} }
ttlStr := userEntry.TTL.String() ttlStr := userEntry.TTL.String()

View File

@ -0,0 +1,87 @@
package policyutil
import (
"sort"
"strings"
)
func ParsePolicies(policiesRaw string) []string {
policies := strings.Split(policiesRaw, ",")
defaultFound := false
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
// If 'root' policy is present, ignore all other policies.
if policies[i] == "root" {
policies = []string{"root"}
defaultFound = true
break
}
if policies[i] == "default" {
defaultFound = true
}
}
// Always add 'default' except only if the policies contain 'root'.
if len(policies) == 0 || !defaultFound {
policies = append(policies, "default")
}
// Sort to make the computations on policies consistent.
sort.Strings(policies)
return policies
}
// ComparePolicies checks whether the given policy sets are equivalent, as in,
// they contain the same values. The benefit of this method is that it leaves
// the "default" policy out of its comparisons as it may be added later by core
// after a set of policies has been saved by a backend.
func EquivalentPolicies(a, b []string) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
// First we'll build maps to ensure unique values and filter default
mapA := map[string]bool{}
mapB := map[string]bool{}
for _, keyA := range a {
if keyA == "default" {
continue
}
mapA[keyA] = true
}
for _, keyB := range b {
if keyB == "default" {
continue
}
mapB[keyB] = true
}
// Now we'll build our checking slices
var sortedA, sortedB []string
for keyA, _ := range mapA {
sortedA = append(sortedA, keyA)
}
for keyB, _ := range mapB {
sortedB = append(sortedB, keyB)
}
sort.Strings(sortedA)
sort.Strings(sortedB)
// Finally, compare
if len(sortedA) != len(sortedB) {
return false
}
for i := range sortedA {
if sortedA[i] != sortedB[i] {
return false
}
}
return true
}

View File

@ -0,0 +1,61 @@
package policyutil
import "testing"
func TestParsePolicies(t *testing.T) {
expected := []string{"foo", "bar", "default"}
actual := ParsePolicies("foo,bar")
// add default if not present.
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// do not add default more than once.
actual = ParsePolicies("foo,bar,default")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// handle spaces and tabs.
actual = ParsePolicies(" foo , bar , default")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// ignore all others if root is present.
expected = []string{"root"}
actual = ParsePolicies("foo,bar,root")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// with spaces and tabs.
expected = []string{"root"}
actual = ParsePolicies("foo ,bar, root ")
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
}
func TestEquivalentPolicies(t *testing.T) {
a := []string{"foo", "bar"}
var b []string
if EquivalentPolicies(a, b) {
t.Fatal("bad")
}
b = []string{"foo"}
if EquivalentPolicies(a, b) {
t.Fatal("bad")
}
b = []string{"bar", "foo"}
if !EquivalentPolicies(a, b) {
t.Fatal("bad")
}
b = []string{"foo", "default", "bar"}
if !EquivalentPolicies(a, b) {
t.Fatal("bad")
}
}

22
helper/strutil/strutil.go Normal file
View File

@ -0,0 +1,22 @@
package strutil
// StrListContains looks for a string in a list of strings.
func StrListContains(haystack []string, needle string) bool {
for _, item := range haystack {
if item == needle {
return true
}
}
return false
}
// StrListSubset checks if a given list is a subset
// of another set
func StrListSubset(super, sub []string) bool {
for _, item := range sub {
if !StrListContains(super, item) {
return false
}
}
return true
}

View File

@ -0,0 +1,49 @@
package strutil
import "testing"
func TestStrListContains(t *testing.T) {
haystack := []string{
"dev",
"ops",
"prod",
"root",
}
if StrListContains(haystack, "tubez") {
t.Fatalf("Bad")
}
if !StrListContains(haystack, "root") {
t.Fatalf("Bad")
}
}
func TestStrListSubset(t *testing.T) {
parent := []string{
"dev",
"ops",
"prod",
"root",
}
child := []string{
"prod",
"ops",
}
if !StrListSubset(parent, child) {
t.Fatalf("Bad")
}
if !StrListSubset(parent, parent) {
t.Fatalf("Bad")
}
if !StrListSubset(child, child) {
t.Fatalf("Bad")
}
if !StrListSubset(child, nil) {
t.Fatalf("Bad")
}
if StrListSubset(child, parent) {
t.Fatalf("Bad")
}
if StrListSubset(nil, child) {
t.Fatalf("Bad")
}
}

View File

@ -18,6 +18,7 @@ import (
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/mlock" "github.com/hashicorp/vault/helper/mlock"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/physical"
"github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/shamir"
@ -595,7 +596,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
} }
// Set the default lease if non-provided, root tokens are exempt // Set the default lease if non-provided, root tokens are exempt
if auth.TTL == 0 && !strListContains(auth.Policies, "root") { if auth.TTL == 0 && !strutil.StrListContains(auth.Policies, "root") {
auth.TTL = sysView.DefaultLeaseTTL() auth.TTL = sysView.DefaultLeaseTTL()
} }
@ -614,7 +615,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
TTL: auth.TTL, TTL: auth.TTL,
} }
if strListSubset(te.Policies, []string{"root"}) { if strutil.StrListSubset(te.Policies, []string{"root"}) {
te.Policies = []string{"root"} te.Policies = []string{"root"}
} else { } else {
// Use a map to filter out/prevent duplicates // Use a map to filter out/prevent duplicates

View File

@ -12,6 +12,7 @@ import (
"github.com/fatih/structs" "github.com/fatih/structs"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
@ -889,7 +890,7 @@ func (ts *TokenStore) handleCreateCommon(
if len(data.Policies) == 0 { if len(data.Policies) == 0 {
data.Policies = role.AllowedPolicies data.Policies = role.AllowedPolicies
} else { } else {
if !strListSubset(role.AllowedPolicies, data.Policies) { if !strutil.StrListSubset(role.AllowedPolicies, data.Policies) {
return logical.ErrorResponse("token policies must be subset of the role's allowed policies"), logical.ErrInvalidRequest return logical.ErrorResponse("token policies must be subset of the role's allowed policies"), logical.ErrInvalidRequest
} }
} }
@ -899,7 +900,7 @@ func (ts *TokenStore) handleCreateCommon(
// When a role is not in use, only permit policies to be a subset unless // When a role is not in use, only permit policies to be a subset unless
// the client has root or sudo privileges // the client has root or sudo privileges
case !isSudo && !strListSubset(parent.Policies, data.Policies): case !isSudo && !strutil.StrListSubset(parent.Policies, data.Policies):
return logical.ErrorResponse("child policies must be subset of parent"), logical.ErrInvalidRequest return logical.ErrorResponse("child policies must be subset of parent"), logical.ErrInvalidRequest
} }
@ -972,7 +973,7 @@ func (ts *TokenStore) handleCreateCommon(
sysView := ts.System() sysView := ts.System()
// Set the default lease if non-provided, root tokens are exempt // Set the default lease if non-provided, root tokens are exempt
if te.TTL == 0 && !strListContains(te.Policies, "root") { if te.TTL == 0 && !strutil.StrListContains(te.Policies, "root") {
te.TTL = sysView.DefaultLeaseTTL() te.TTL = sysView.DefaultLeaseTTL()
} }

View File

@ -16,49 +16,3 @@ func TestRandBytes(t *testing.T) {
t.Fatalf("bad: %v", b) t.Fatalf("bad: %v", b)
} }
} }
func TestStrListContains(t *testing.T) {
haystack := []string{
"dev",
"ops",
"prod",
"root",
}
if strListContains(haystack, "tubez") {
t.Fatalf("Bad")
}
if !strListContains(haystack, "root") {
t.Fatalf("Bad")
}
}
func TestStrListSubset(t *testing.T) {
parent := []string{
"dev",
"ops",
"prod",
"root",
}
child := []string{
"prod",
"ops",
}
if !strListSubset(parent, child) {
t.Fatalf("Bad")
}
if !strListSubset(parent, parent) {
t.Fatalf("Bad")
}
if !strListSubset(child, child) {
t.Fatalf("Bad")
}
if !strListSubset(child, nil) {
t.Fatalf("Bad")
}
if strListSubset(child, parent) {
t.Fatalf("Bad")
}
if strListSubset(nil, child) {
t.Fatalf("Bad")
}
}

View File

@ -6,13 +6,21 @@ import (
"fmt" "fmt"
) )
// GenerateRandomBytes is used to generate random bytes of given size.
func GenerateRandomBytes(size int) ([]byte, error) {
buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
return nil, fmt.Errorf("failed to read random bytes: %v", err)
}
return buf, nil
}
// GenerateUUID is used to generate a random UUID // GenerateUUID is used to generate a random UUID
func GenerateUUID() (string, error) { func GenerateUUID() (string, error) {
buf := make([]byte, 16) buf, err := GenerateRandomBytes(16)
if _, err := rand.Read(buf); err != nil { if err != nil {
return "", fmt.Errorf("failed to read random bytes: %v", err) return "", err
} }
return FormatUUID(buf) return FormatUUID(buf)
} }

View File

@ -247,7 +247,7 @@ of the header should be "X-Vault-Token" and the value should be the token.
</li> </li>
<li> <li>
<span class="param">policies</span> <span class="param">policies</span>
<span class="param-flags">required</span> <span class="param-flags">optional</span>
A comma-separated list of policies to set on tokens issued when A comma-separated list of policies to set on tokens issued when
authenticating against this CA certificate. authenticating against this CA certificate.
</li> </li>