diff --git a/builtin/credential/app-id/path_login.go b/builtin/credential/app-id/path_login.go index 927088f54..40925817b 100644 --- a/builtin/credential/app-id/path_login.go +++ b/builtin/credential/app-id/path_login.go @@ -8,7 +8,7 @@ import ( "net" "strings" - "github.com/hashicorp/vault/helper/policies" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -123,7 +123,7 @@ func (b *backend) pathLoginRenew( if err != nil { return nil, err } - if !policies.EquivalentPolicies(mapPolicies, req.Auth.Policies) { + if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.Policies) { return logical.ErrorResponse("policies do not match"), nil } diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index f187ce0f5..d842bb841 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -137,19 +138,13 @@ func (b *backend) pathCertWrite( name := strings.ToLower(d.Get("name").(string)) certificate := d.Get("certificate").(string) displayName := d.Get("display_name").(string) - policies := strings.Split(d.Get("policies").(string), ",") - for i, p := range policies { - policies[i] = strings.TrimSpace(p) - } + policies := policyutil.ParsePolicies(d.Get("policies").(string)) // Default the display name to the certificate name if not given if displayName == "" { displayName = name } - if len(policies) == 0 { - return logical.ErrorResponse("policies required"), nil - } parsed := parsePEM([]byte(certificate)) if len(parsed) == 0 { return logical.ErrorResponse("failed to parse certificate"), nil diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 348fe332b..00a1dabc1 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -10,7 +10,7 @@ import ( "strings" "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/policies" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew( return nil, nil } - if !policies.EquivalentPolicies(cert.Policies, req.Auth.Policies) { + if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.Policies) { return logical.ErrorResponse("policies have changed, not renewing"), nil } diff --git a/builtin/credential/github/path_login.go b/builtin/credential/github/path_login.go index b1175648d..feac25476 100644 --- a/builtin/credential/github/path_login.go +++ b/builtin/credential/github/path_login.go @@ -5,7 +5,7 @@ import ( "net/url" "github.com/google/go-github/github" - "github.com/hashicorp/vault/helper/policies" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew( } else { verifyResp = verifyResponse } - if !policies.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) { + if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) { return logical.ErrorResponse("policies do not match"), nil } diff --git a/builtin/credential/ldap/path_groups.go b/builtin/credential/ldap/path_groups.go index 44a0449e9..c13ead699 100644 --- a/builtin/credential/ldap/path_groups.go +++ b/builtin/credential/ldap/path_groups.go @@ -3,6 +3,7 @@ package ldap import ( "strings" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -25,7 +26,7 @@ func pathGroups(b *backend) *framework.Path { Callbacks: map[logical.Operation]framework.OperationFunc{ logical.DeleteOperation: b.pathGroupDelete, logical.ReadOperation: b.pathGroupRead, - logical.UpdateOperation: b.pathGroupWrite, + logical.UpdateOperation: b.pathGroupWrite, }, HelpSynopsis: pathGroupHelpSyn, @@ -79,15 +80,9 @@ func (b *backend) pathGroupRead( func (b *backend) pathGroupWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - name := d.Get("name").(string) - policies := strings.Split(d.Get("policies").(string), ",") - for i, p := range policies { - policies[i] = strings.TrimSpace(p) - } - // Store it - entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{ - Policies: policies, + entry, err := logical.StorageEntryJSON("group/"+d.Get("name").(string), &GroupEntry{ + Policies: policyutil.ParsePolicies(d.Get("policies").(string)), }) if err != nil { return nil, err diff --git a/builtin/credential/ldap/path_login.go b/builtin/credential/ldap/path_login.go index 0f4e67835..b5b8ad7b5 100644 --- a/builtin/credential/ldap/path_login.go +++ b/builtin/credential/ldap/path_login.go @@ -4,7 +4,7 @@ import ( "sort" "strings" - "github.com/hashicorp/vault/helper/policies" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -82,7 +82,7 @@ func (b *backend) pathLoginRenew( return resp, err } - if !policies.EquivalentPolicies(loginPolicies, req.Auth.Policies) { + if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.Policies) { return logical.ErrorResponse("policies have changed, not renewing"), nil } diff --git a/builtin/credential/userpass/path_login.go b/builtin/credential/userpass/path_login.go index 088ac8ffe..e7c98a8af 100644 --- a/builtin/credential/userpass/path_login.go +++ b/builtin/credential/userpass/path_login.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/helper/policies" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "golang.org/x/crypto/bcrypt" @@ -93,7 +93,7 @@ func (b *backend) pathLoginRenew( return nil, nil } - if !policies.EquivalentPolicies(user.Policies, req.Auth.Policies) { + if !policyutil.EquivalentPolicies(user.Policies, req.Auth.Policies) { return logical.ErrorResponse("policies have changed, not renewing"), nil } diff --git a/builtin/credential/userpass/path_user_policies.go b/builtin/credential/userpass/path_user_policies.go index 73b9fe6d6..9b586c619 100644 --- a/builtin/credential/userpass/path_user_policies.go +++ b/builtin/credential/userpass/path_user_policies.go @@ -2,8 +2,8 @@ package userpass import ( "fmt" - "strings" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -52,23 +52,11 @@ func (b *backend) pathUserPoliciesUpdate( return nil, fmt.Errorf("username does not exist") } - err = b.updateUserPolicies(req, d, userEntry) - if err != nil { - return nil, err - } + userEntry.Policies = policyutil.ParsePolicies(d.Get("policies").(string)) return nil, b.setUser(req.Storage, username, userEntry) } -func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error { - policies := strings.Split(d.Get("policies").(string), ",") - for i, p := range policies { - policies[i] = strings.TrimSpace(p) - } - userEntry.Policies = policies - return nil -} - const pathUserPoliciesHelpSyn = ` Update the policies associated with the username. ` diff --git a/builtin/credential/userpass/path_users.go b/builtin/credential/userpass/path_users.go index c4873b5e1..8a2f67edf 100644 --- a/builtin/credential/userpass/path_users.go +++ b/builtin/credential/userpass/path_users.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -137,11 +138,8 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData) } } - if _, ok := d.GetOk("policies"); ok { - err = b.updateUserPolicies(req, d, userEntry) - if err != nil { - return nil, err - } + if policiesRaw, ok := d.GetOk("policies"); ok { + userEntry.Policies = policyutil.ParsePolicies(policiesRaw.(string)) } ttlStr := userEntry.TTL.String() diff --git a/helper/policyutil/policyutil.go b/helper/policyutil/policyutil.go new file mode 100644 index 000000000..31f2674d3 --- /dev/null +++ b/helper/policyutil/policyutil.go @@ -0,0 +1,87 @@ +package policyutil + +import ( + "sort" + "strings" +) + +func ParsePolicies(policiesRaw string) []string { + policies := strings.Split(policiesRaw, ",") + defaultFound := false + for i, p := range policies { + policies[i] = strings.TrimSpace(p) + // If 'root' policy is present, ignore all other policies. + if policies[i] == "root" { + policies = []string{"root"} + defaultFound = true + break + } + if policies[i] == "default" { + defaultFound = true + } + } + + // Always add 'default' except only if the policies contain 'root'. + if len(policies) == 0 || !defaultFound { + policies = append(policies, "default") + } + + // Sort to make the computations on policies consistent. + sort.Strings(policies) + + return policies +} + +// ComparePolicies checks whether the given policy sets are equivalent, as in, +// they contain the same values. The benefit of this method is that it leaves +// the "default" policy out of its comparisons as it may be added later by core +// after a set of policies has been saved by a backend. +func EquivalentPolicies(a, b []string) bool { + if a == nil && b == nil { + return true + } + + if a == nil || b == nil { + return false + } + + // First we'll build maps to ensure unique values and filter default + mapA := map[string]bool{} + mapB := map[string]bool{} + for _, keyA := range a { + if keyA == "default" { + continue + } + mapA[keyA] = true + } + for _, keyB := range b { + if keyB == "default" { + continue + } + mapB[keyB] = true + } + + // Now we'll build our checking slices + var sortedA, sortedB []string + for keyA, _ := range mapA { + sortedA = append(sortedA, keyA) + } + for keyB, _ := range mapB { + sortedB = append(sortedB, keyB) + } + sort.Strings(sortedA) + sort.Strings(sortedB) + + // Finally, compare + if len(sortedA) != len(sortedB) { + return false + } + + for i := range sortedA { + if sortedA[i] != sortedB[i] { + return false + } + } + + return true +} diff --git a/helper/policyutil/policyutil_test.go b/helper/policyutil/policyutil_test.go new file mode 100644 index 000000000..8273611f9 --- /dev/null +++ b/helper/policyutil/policyutil_test.go @@ -0,0 +1,61 @@ +package policyutil + +import "testing" + +func TestParsePolicies(t *testing.T) { + expected := []string{"foo", "bar", "default"} + actual := ParsePolicies("foo,bar") + // add default if not present. + if !EquivalentPolicies(expected, actual) { + t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual) + } + + // do not add default more than once. + actual = ParsePolicies("foo,bar,default") + if !EquivalentPolicies(expected, actual) { + t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual) + } + + // handle spaces and tabs. + actual = ParsePolicies(" foo , bar , default") + if !EquivalentPolicies(expected, actual) { + t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual) + } + + // ignore all others if root is present. + expected = []string{"root"} + actual = ParsePolicies("foo,bar,root") + if !EquivalentPolicies(expected, actual) { + t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual) + } + + // with spaces and tabs. + expected = []string{"root"} + actual = ParsePolicies("foo ,bar, root ") + if !EquivalentPolicies(expected, actual) { + t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual) + } +} + +func TestEquivalentPolicies(t *testing.T) { + a := []string{"foo", "bar"} + var b []string + if EquivalentPolicies(a, b) { + t.Fatal("bad") + } + + b = []string{"foo"} + if EquivalentPolicies(a, b) { + t.Fatal("bad") + } + + b = []string{"bar", "foo"} + if !EquivalentPolicies(a, b) { + t.Fatal("bad") + } + + b = []string{"foo", "default", "bar"} + if !EquivalentPolicies(a, b) { + t.Fatal("bad") + } +} diff --git a/helper/strutil/strutil.go b/helper/strutil/strutil.go new file mode 100644 index 000000000..de558e8cf --- /dev/null +++ b/helper/strutil/strutil.go @@ -0,0 +1,22 @@ +package strutil + +// StrListContains looks for a string in a list of strings. +func StrListContains(haystack []string, needle string) bool { + for _, item := range haystack { + if item == needle { + return true + } + } + return false +} + +// StrListSubset checks if a given list is a subset +// of another set +func StrListSubset(super, sub []string) bool { + for _, item := range sub { + if !StrListContains(super, item) { + return false + } + } + return true +} diff --git a/helper/strutil/strutil_test.go b/helper/strutil/strutil_test.go new file mode 100644 index 000000000..d6dced7c4 --- /dev/null +++ b/helper/strutil/strutil_test.go @@ -0,0 +1,49 @@ +package strutil + +import "testing" + +func TestStrListContains(t *testing.T) { + haystack := []string{ + "dev", + "ops", + "prod", + "root", + } + if StrListContains(haystack, "tubez") { + t.Fatalf("Bad") + } + if !StrListContains(haystack, "root") { + t.Fatalf("Bad") + } +} + +func TestStrListSubset(t *testing.T) { + parent := []string{ + "dev", + "ops", + "prod", + "root", + } + child := []string{ + "prod", + "ops", + } + if !StrListSubset(parent, child) { + t.Fatalf("Bad") + } + if !StrListSubset(parent, parent) { + t.Fatalf("Bad") + } + if !StrListSubset(child, child) { + t.Fatalf("Bad") + } + if !StrListSubset(child, nil) { + t.Fatalf("Bad") + } + if StrListSubset(child, parent) { + t.Fatalf("Bad") + } + if StrListSubset(nil, child) { + t.Fatalf("Bad") + } +} diff --git a/vault/core.go b/vault/core.go index b23894ccd..b765ea2f3 100644 --- a/vault/core.go +++ b/vault/core.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/mlock" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/shamir" @@ -595,7 +596,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log } // Set the default lease if non-provided, root tokens are exempt - if auth.TTL == 0 && !strListContains(auth.Policies, "root") { + if auth.TTL == 0 && !strutil.StrListContains(auth.Policies, "root") { auth.TTL = sysView.DefaultLeaseTTL() } @@ -614,7 +615,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log TTL: auth.TTL, } - if strListSubset(te.Policies, []string{"root"}) { + if strutil.StrListSubset(te.Policies, []string{"root"}) { te.Policies = []string{"root"} } else { // Use a map to filter out/prevent duplicates diff --git a/vault/token_store.go b/vault/token_store.go index 92b467748..40b4c21d1 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -12,6 +12,7 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/salt" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "github.com/mitchellh/mapstructure" @@ -889,7 +890,7 @@ func (ts *TokenStore) handleCreateCommon( if len(data.Policies) == 0 { data.Policies = role.AllowedPolicies } else { - if !strListSubset(role.AllowedPolicies, data.Policies) { + if !strutil.StrListSubset(role.AllowedPolicies, data.Policies) { return logical.ErrorResponse("token policies must be subset of the role's allowed policies"), logical.ErrInvalidRequest } } @@ -899,7 +900,7 @@ func (ts *TokenStore) handleCreateCommon( // When a role is not in use, only permit policies to be a subset unless // the client has root or sudo privileges - case !isSudo && !strListSubset(parent.Policies, data.Policies): + case !isSudo && !strutil.StrListSubset(parent.Policies, data.Policies): return logical.ErrorResponse("child policies must be subset of parent"), logical.ErrInvalidRequest } @@ -972,7 +973,7 @@ func (ts *TokenStore) handleCreateCommon( sysView := ts.System() // Set the default lease if non-provided, root tokens are exempt - if te.TTL == 0 && !strListContains(te.Policies, "root") { + if te.TTL == 0 && !strutil.StrListContains(te.Policies, "root") { te.TTL = sysView.DefaultLeaseTTL() } diff --git a/vault/util_test.go b/vault/util_test.go index f7b65d069..70fe1d78d 100644 --- a/vault/util_test.go +++ b/vault/util_test.go @@ -16,49 +16,3 @@ func TestRandBytes(t *testing.T) { t.Fatalf("bad: %v", b) } } - -func TestStrListContains(t *testing.T) { - haystack := []string{ - "dev", - "ops", - "prod", - "root", - } - if strListContains(haystack, "tubez") { - t.Fatalf("Bad") - } - if !strListContains(haystack, "root") { - t.Fatalf("Bad") - } -} - -func TestStrListSubset(t *testing.T) { - parent := []string{ - "dev", - "ops", - "prod", - "root", - } - child := []string{ - "prod", - "ops", - } - if !strListSubset(parent, child) { - t.Fatalf("Bad") - } - if !strListSubset(parent, parent) { - t.Fatalf("Bad") - } - if !strListSubset(child, child) { - t.Fatalf("Bad") - } - if !strListSubset(child, nil) { - t.Fatalf("Bad") - } - if strListSubset(child, parent) { - t.Fatalf("Bad") - } - if strListSubset(nil, child) { - t.Fatalf("Bad") - } -} diff --git a/vendor/github.com/hashicorp/go-uuid/uuid.go b/vendor/github.com/hashicorp/go-uuid/uuid.go index 322b522c2..ff9364c40 100644 --- a/vendor/github.com/hashicorp/go-uuid/uuid.go +++ b/vendor/github.com/hashicorp/go-uuid/uuid.go @@ -6,13 +6,21 @@ import ( "fmt" ) +// GenerateRandomBytes is used to generate random bytes of given size. +func GenerateRandomBytes(size int) ([]byte, error) { + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return nil, fmt.Errorf("failed to read random bytes: %v", err) + } + return buf, nil +} + // GenerateUUID is used to generate a random UUID func GenerateUUID() (string, error) { - buf := make([]byte, 16) - if _, err := rand.Read(buf); err != nil { - return "", fmt.Errorf("failed to read random bytes: %v", err) + buf, err := GenerateRandomBytes(16) + if err != nil { + return "", err } - return FormatUUID(buf) } diff --git a/website/source/docs/auth/cert.html.md b/website/source/docs/auth/cert.html.md index 7147f7ac8..c443796b9 100644 --- a/website/source/docs/auth/cert.html.md +++ b/website/source/docs/auth/cert.html.md @@ -247,7 +247,7 @@ of the header should be "X-Vault-Token" and the value should be the token.