Merge pull request #1299 from hashicorp/utility-enhancements
Utility Enhancements
This commit is contained in:
commit
dbc8162ae4
|
@ -8,7 +8,7 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policies"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -123,7 +123,7 @@ func (b *backend) pathLoginRenew(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !policies.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
|
||||
if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies do not match"), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -137,19 +138,13 @@ func (b *backend) pathCertWrite(
|
|||
name := strings.ToLower(d.Get("name").(string))
|
||||
certificate := d.Get("certificate").(string)
|
||||
displayName := d.Get("display_name").(string)
|
||||
policies := strings.Split(d.Get("policies").(string), ",")
|
||||
for i, p := range policies {
|
||||
policies[i] = strings.TrimSpace(p)
|
||||
}
|
||||
policies := policyutil.ParsePolicies(d.Get("policies").(string))
|
||||
|
||||
// Default the display name to the certificate name if not given
|
||||
if displayName == "" {
|
||||
displayName = name
|
||||
}
|
||||
|
||||
if len(policies) == 0 {
|
||||
return logical.ErrorResponse("policies required"), nil
|
||||
}
|
||||
parsed := parsePEM([]byte(certificate))
|
||||
if len(parsed) == 0 {
|
||||
return logical.ErrorResponse("failed to parse certificate"), nil
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/policies"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
if !policies.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
|
||||
if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/url"
|
||||
|
||||
"github.com/google/go-github/github"
|
||||
"github.com/hashicorp/vault/helper/policies"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew(
|
|||
} else {
|
||||
verifyResp = verifyResponse
|
||||
}
|
||||
if !policies.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
|
||||
if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies do not match"), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package ldap
|
|||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -25,7 +26,7 @@ func pathGroups(b *backend) *framework.Path {
|
|||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.DeleteOperation: b.pathGroupDelete,
|
||||
logical.ReadOperation: b.pathGroupRead,
|
||||
logical.UpdateOperation: b.pathGroupWrite,
|
||||
logical.UpdateOperation: b.pathGroupWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathGroupHelpSyn,
|
||||
|
@ -79,15 +80,9 @@ func (b *backend) pathGroupRead(
|
|||
|
||||
func (b *backend) pathGroupWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
policies := strings.Split(d.Get("policies").(string), ",")
|
||||
for i, p := range policies {
|
||||
policies[i] = strings.TrimSpace(p)
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{
|
||||
Policies: policies,
|
||||
entry, err := logical.StorageEntryJSON("group/"+d.Get("name").(string), &GroupEntry{
|
||||
Policies: policyutil.ParsePolicies(d.Get("policies").(string)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policies"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew(
|
|||
return resp, err
|
||||
}
|
||||
|
||||
if !policies.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
|
||||
if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policies"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
@ -93,7 +93,7 @@ func (b *backend) pathLoginRenew(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
if !policies.EquivalentPolicies(user.Policies, req.Auth.Policies) {
|
||||
if !policyutil.EquivalentPolicies(user.Policies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,8 @@ package userpass
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -52,23 +52,11 @@ func (b *backend) pathUserPoliciesUpdate(
|
|||
return nil, fmt.Errorf("username does not exist")
|
||||
}
|
||||
|
||||
err = b.updateUserPolicies(req, d, userEntry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userEntry.Policies = policyutil.ParsePolicies(d.Get("policies").(string))
|
||||
|
||||
return nil, b.setUser(req.Storage, username, userEntry)
|
||||
}
|
||||
|
||||
func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
|
||||
policies := strings.Split(d.Get("policies").(string), ",")
|
||||
for i, p := range policies {
|
||||
policies[i] = strings.TrimSpace(p)
|
||||
}
|
||||
userEntry.Policies = policies
|
||||
return nil
|
||||
}
|
||||
|
||||
const pathUserPoliciesHelpSyn = `
|
||||
Update the policies associated with the username.
|
||||
`
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -137,11 +138,8 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData)
|
|||
}
|
||||
}
|
||||
|
||||
if _, ok := d.GetOk("policies"); ok {
|
||||
err = b.updateUserPolicies(req, d, userEntry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policiesRaw, ok := d.GetOk("policies"); ok {
|
||||
userEntry.Policies = policyutil.ParsePolicies(policiesRaw.(string))
|
||||
}
|
||||
|
||||
ttlStr := userEntry.TTL.String()
|
||||
|
|
87
helper/policyutil/policyutil.go
Normal file
87
helper/policyutil/policyutil.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package policyutil
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ParsePolicies(policiesRaw string) []string {
|
||||
policies := strings.Split(policiesRaw, ",")
|
||||
defaultFound := false
|
||||
for i, p := range policies {
|
||||
policies[i] = strings.TrimSpace(p)
|
||||
// If 'root' policy is present, ignore all other policies.
|
||||
if policies[i] == "root" {
|
||||
policies = []string{"root"}
|
||||
defaultFound = true
|
||||
break
|
||||
}
|
||||
if policies[i] == "default" {
|
||||
defaultFound = true
|
||||
}
|
||||
}
|
||||
|
||||
// Always add 'default' except only if the policies contain 'root'.
|
||||
if len(policies) == 0 || !defaultFound {
|
||||
policies = append(policies, "default")
|
||||
}
|
||||
|
||||
// Sort to make the computations on policies consistent.
|
||||
sort.Strings(policies)
|
||||
|
||||
return policies
|
||||
}
|
||||
|
||||
// ComparePolicies checks whether the given policy sets are equivalent, as in,
|
||||
// they contain the same values. The benefit of this method is that it leaves
|
||||
// the "default" policy out of its comparisons as it may be added later by core
|
||||
// after a set of policies has been saved by a backend.
|
||||
func EquivalentPolicies(a, b []string) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// First we'll build maps to ensure unique values and filter default
|
||||
mapA := map[string]bool{}
|
||||
mapB := map[string]bool{}
|
||||
for _, keyA := range a {
|
||||
if keyA == "default" {
|
||||
continue
|
||||
}
|
||||
mapA[keyA] = true
|
||||
}
|
||||
for _, keyB := range b {
|
||||
if keyB == "default" {
|
||||
continue
|
||||
}
|
||||
mapB[keyB] = true
|
||||
}
|
||||
|
||||
// Now we'll build our checking slices
|
||||
var sortedA, sortedB []string
|
||||
for keyA, _ := range mapA {
|
||||
sortedA = append(sortedA, keyA)
|
||||
}
|
||||
for keyB, _ := range mapB {
|
||||
sortedB = append(sortedB, keyB)
|
||||
}
|
||||
sort.Strings(sortedA)
|
||||
sort.Strings(sortedB)
|
||||
|
||||
// Finally, compare
|
||||
if len(sortedA) != len(sortedB) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range sortedA {
|
||||
if sortedA[i] != sortedB[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
61
helper/policyutil/policyutil_test.go
Normal file
61
helper/policyutil/policyutil_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package policyutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParsePolicies(t *testing.T) {
|
||||
expected := []string{"foo", "bar", "default"}
|
||||
actual := ParsePolicies("foo,bar")
|
||||
// add default if not present.
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
|
||||
// do not add default more than once.
|
||||
actual = ParsePolicies("foo,bar,default")
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
|
||||
// handle spaces and tabs.
|
||||
actual = ParsePolicies(" foo , bar , default")
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
|
||||
// ignore all others if root is present.
|
||||
expected = []string{"root"}
|
||||
actual = ParsePolicies("foo,bar,root")
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
|
||||
// with spaces and tabs.
|
||||
expected = []string{"root"}
|
||||
actual = ParsePolicies("foo ,bar, root ")
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEquivalentPolicies(t *testing.T) {
|
||||
a := []string{"foo", "bar"}
|
||||
var b []string
|
||||
if EquivalentPolicies(a, b) {
|
||||
t.Fatal("bad")
|
||||
}
|
||||
|
||||
b = []string{"foo"}
|
||||
if EquivalentPolicies(a, b) {
|
||||
t.Fatal("bad")
|
||||
}
|
||||
|
||||
b = []string{"bar", "foo"}
|
||||
if !EquivalentPolicies(a, b) {
|
||||
t.Fatal("bad")
|
||||
}
|
||||
|
||||
b = []string{"foo", "default", "bar"}
|
||||
if !EquivalentPolicies(a, b) {
|
||||
t.Fatal("bad")
|
||||
}
|
||||
}
|
22
helper/strutil/strutil.go
Normal file
22
helper/strutil/strutil.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package strutil
|
||||
|
||||
// StrListContains looks for a string in a list of strings.
|
||||
func StrListContains(haystack []string, needle string) bool {
|
||||
for _, item := range haystack {
|
||||
if item == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StrListSubset checks if a given list is a subset
|
||||
// of another set
|
||||
func StrListSubset(super, sub []string) bool {
|
||||
for _, item := range sub {
|
||||
if !StrListContains(super, item) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
49
helper/strutil/strutil_test.go
Normal file
49
helper/strutil/strutil_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package strutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestStrListContains(t *testing.T) {
|
||||
haystack := []string{
|
||||
"dev",
|
||||
"ops",
|
||||
"prod",
|
||||
"root",
|
||||
}
|
||||
if StrListContains(haystack, "tubez") {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !StrListContains(haystack, "root") {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrListSubset(t *testing.T) {
|
||||
parent := []string{
|
||||
"dev",
|
||||
"ops",
|
||||
"prod",
|
||||
"root",
|
||||
}
|
||||
child := []string{
|
||||
"prod",
|
||||
"ops",
|
||||
}
|
||||
if !StrListSubset(parent, child) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !StrListSubset(parent, parent) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !StrListSubset(child, child) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !StrListSubset(child, nil) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if StrListSubset(child, parent) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if StrListSubset(nil, child) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/mlock"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
"github.com/hashicorp/vault/shamir"
|
||||
|
@ -595,7 +596,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
}
|
||||
|
||||
// Set the default lease if non-provided, root tokens are exempt
|
||||
if auth.TTL == 0 && !strListContains(auth.Policies, "root") {
|
||||
if auth.TTL == 0 && !strutil.StrListContains(auth.Policies, "root") {
|
||||
auth.TTL = sysView.DefaultLeaseTTL()
|
||||
}
|
||||
|
||||
|
@ -614,7 +615,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
TTL: auth.TTL,
|
||||
}
|
||||
|
||||
if strListSubset(te.Policies, []string{"root"}) {
|
||||
if strutil.StrListSubset(te.Policies, []string{"root"}) {
|
||||
te.Policies = []string{"root"}
|
||||
} else {
|
||||
// Use a map to filter out/prevent duplicates
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -889,7 +890,7 @@ func (ts *TokenStore) handleCreateCommon(
|
|||
if len(data.Policies) == 0 {
|
||||
data.Policies = role.AllowedPolicies
|
||||
} else {
|
||||
if !strListSubset(role.AllowedPolicies, data.Policies) {
|
||||
if !strutil.StrListSubset(role.AllowedPolicies, data.Policies) {
|
||||
return logical.ErrorResponse("token policies must be subset of the role's allowed policies"), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
@ -899,7 +900,7 @@ func (ts *TokenStore) handleCreateCommon(
|
|||
|
||||
// When a role is not in use, only permit policies to be a subset unless
|
||||
// the client has root or sudo privileges
|
||||
case !isSudo && !strListSubset(parent.Policies, data.Policies):
|
||||
case !isSudo && !strutil.StrListSubset(parent.Policies, data.Policies):
|
||||
return logical.ErrorResponse("child policies must be subset of parent"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
|
@ -972,7 +973,7 @@ func (ts *TokenStore) handleCreateCommon(
|
|||
sysView := ts.System()
|
||||
|
||||
// Set the default lease if non-provided, root tokens are exempt
|
||||
if te.TTL == 0 && !strListContains(te.Policies, "root") {
|
||||
if te.TTL == 0 && !strutil.StrListContains(te.Policies, "root") {
|
||||
te.TTL = sysView.DefaultLeaseTTL()
|
||||
}
|
||||
|
||||
|
|
|
@ -16,49 +16,3 @@ func TestRandBytes(t *testing.T) {
|
|||
t.Fatalf("bad: %v", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrListContains(t *testing.T) {
|
||||
haystack := []string{
|
||||
"dev",
|
||||
"ops",
|
||||
"prod",
|
||||
"root",
|
||||
}
|
||||
if strListContains(haystack, "tubez") {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !strListContains(haystack, "root") {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrListSubset(t *testing.T) {
|
||||
parent := []string{
|
||||
"dev",
|
||||
"ops",
|
||||
"prod",
|
||||
"root",
|
||||
}
|
||||
child := []string{
|
||||
"prod",
|
||||
"ops",
|
||||
}
|
||||
if !strListSubset(parent, child) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !strListSubset(parent, parent) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !strListSubset(child, child) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if !strListSubset(child, nil) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if strListSubset(child, parent) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
if strListSubset(nil, child) {
|
||||
t.Fatalf("Bad")
|
||||
}
|
||||
}
|
||||
|
|
16
vendor/github.com/hashicorp/go-uuid/uuid.go
generated
vendored
16
vendor/github.com/hashicorp/go-uuid/uuid.go
generated
vendored
|
@ -6,13 +6,21 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// GenerateRandomBytes is used to generate random bytes of given size.
|
||||
func GenerateRandomBytes(size int) ([]byte, error) {
|
||||
buf := make([]byte, size)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return nil, fmt.Errorf("failed to read random bytes: %v", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// GenerateUUID is used to generate a random UUID
|
||||
func GenerateUUID() (string, error) {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %v", err)
|
||||
buf, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return FormatUUID(buf)
|
||||
}
|
||||
|
||||
|
|
|
@ -247,7 +247,7 @@ of the header should be "X-Vault-Token" and the value should be the token.
|
|||
</li>
|
||||
<li>
|
||||
<span class="param">policies</span>
|
||||
<span class="param-flags">required</span>
|
||||
<span class="param-flags">optional</span>
|
||||
A comma-separated list of policies to set on tokens issued when
|
||||
authenticating against this CA certificate.
|
||||
</li>
|
||||
|
|
Loading…
Reference in a new issue