Update existing alias metadata during authentication (#6068)

This commit is contained in:
Jim Kalafut 2019-01-23 08:26:50 -08:00 committed by GitHub
parent aac271ed7f
commit f097b8d934
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 179 additions and 26 deletions

View File

@ -301,6 +301,24 @@ func EquivalentSlices(a, b []string) bool {
return true
}
// EqualStringMaps tests whether two map[string]string objects are equal.
// Equal means both maps have the same sets of keys and values. This function
// is 6-10x faster than a call to reflect.DeepEqual().
func EqualStringMaps(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for k := range a {
v, ok := b[k]
if !ok || a[k] != v {
return false
}
}
return true
}
// StrListDelete removes the first occurrence of the given item from the slice
// of strings if the item exists.
func StrListDelete(s []string, d string) []string {

View File

@ -531,3 +531,44 @@ func TestDifference(t *testing.T) {
})
}
}
func TestStrUtil_EqualStringMaps(t *testing.T) {
m1 := map[string]string{
"foo": "a",
}
m2 := map[string]string{
"foo": "a",
"bar": "b",
}
var m3 map[string]string
m4 := map[string]string{
"dog": "",
}
m5 := map[string]string{
"cat": "",
}
tests := []struct {
a map[string]string
b map[string]string
result bool
}{
{m1, m1, true},
{m2, m2, true},
{m1, m2, false},
{m2, m1, false},
{m2, m2, true},
{m3, m1, false},
{m3, m3, true},
{m4, m5, false},
}
for i, test := range tests {
actual := EqualStringMaps(test.a, test.b)
if actual != test.result {
t.Fatalf("case %d, expected %v, got %v", i, test.result, actual)
}
}
}

View File

@ -13,6 +13,7 @@ import (
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -406,6 +407,7 @@ func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor,
func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.Alias) (*identity.Entity, error) {
var entity *identity.Entity
var err error
var update bool
if alias == nil {
return nil, fmt.Errorf("alias is nil")
@ -428,12 +430,12 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
return nil, fmt.Errorf("mount accessor %q is not a mount of type %q", alias.MountAccessor, alias.MountType)
}
// Check if an entity already exists for the given alais
// Check if an entity already exists for the given alias
entity, err = i.entityByAliasFactors(alias.MountAccessor, alias.Name, false)
if err != nil {
return nil, err
}
if entity != nil {
if entity != nil && changedAliasIndex(entity, alias) == -1 {
return entity, nil
}
@ -445,14 +447,23 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
defer txn.Abort()
// Check if an entity was created before acquiring the lock
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, true)
if err != nil {
return nil, err
}
if entity != nil {
idx := changedAliasIndex(entity, alias)
if idx == -1 {
return entity, nil
}
a := entity.Aliases[idx]
a.Metadata = alias.Metadata
a.LastUpdateTime = ptypes.TimestampNow()
update = true
}
if !update {
entity = new(identity.Entity)
err = i.sanitizeEntity(ctx, entity)
if err != nil {
@ -480,6 +491,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
entity.Aliases = []*identity.Alias{
newAlias,
}
}
// Update MemDB and persist entity object
err = i.upsertEntityInTxn(ctx, txn, entity, nil, true)
@ -491,3 +503,17 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
return entity, nil
}
// changedAliasIndex searches an entity for changed alias metadata.
//
// If a match is found, the changed alias's index is returned. If no alias
// names match or no metadata is different, -1 is returned.
func changedAliasIndex(entity *identity.Entity, alias *logical.Alias) int {
for i, a := range entity.Aliases {
if a.Name == alias.Name && !strutil.EqualStringMaps(a.Metadata, alias.Metadata) {
return i
}
}
return -1
}

View File

@ -96,6 +96,9 @@ func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
MountType: "github",
MountAccessor: ghAccessor,
Name: "githubuser",
Metadata: map[string]string{
"foo": "a",
},
}
entity, err := is.CreateOrFetchEntity(ctx, alias)
@ -129,6 +132,71 @@ func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
if entity.Aliases[0].Name != alias.Name {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}
if diff := deep.Equal(entity.Aliases[0].Metadata, map[string]string{"foo": "a"}); diff != nil {
t.Fatal(diff)
}
// Add a new alias to the entity and verify its existence
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity-alias",
Data: map[string]interface{}{
"name": "githubuser2",
"canonical_id": entity.ID,
"mount_accessor": ghAccessor,
},
}
resp, err := is.HandleRequest(ctx, registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entity, err = is.CreateOrFetchEntity(ctx, alias)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}
if len(entity.Aliases) != 2 {
t.Fatalf("bad: length of aliases; expected: 2, actual: %d", len(entity.Aliases))
}
if entity.Aliases[1].Name != "githubuser2" {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, "githubuser2")
}
if diff := deep.Equal(entity.Aliases[1].Metadata, map[string]string(nil)); diff != nil {
t.Fatal(diff)
}
// Change the metadata of an existing alias and verify that
// a the change takes effect only for the target alias.
alias.Metadata = map[string]string{
"foo": "zzzz",
}
entity, err = is.CreateOrFetchEntity(ctx, alias)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}
if len(entity.Aliases) != 2 {
t.Fatalf("bad: length of aliases; expected: 2, actual: %d", len(entity.Aliases))
}
if diff := deep.Equal(entity.Aliases[0].Metadata, map[string]string{"foo": "zzzz"}); diff != nil {
t.Fatal(diff)
}
if diff := deep.Equal(entity.Aliases[1].Metadata, map[string]string(nil)); diff != nil {
t.Fatal(diff)
}
}
func TestIdentityStore_EntityByAliasFactors(t *testing.T) {