Update existing alias metadata during authentication (#6068)
This commit is contained in:
parent
aac271ed7f
commit
f097b8d934
|
@ -301,6 +301,24 @@ func EquivalentSlices(a, b []string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// EqualStringMaps tests whether two map[string]string objects are equal.
|
||||
// Equal means both maps have the same sets of keys and values. This function
|
||||
// is 6-10x faster than a call to reflect.DeepEqual().
|
||||
func EqualStringMaps(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k := range a {
|
||||
v, ok := b[k]
|
||||
if !ok || a[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// StrListDelete removes the first occurrence of the given item from the slice
|
||||
// of strings if the item exists.
|
||||
func StrListDelete(s []string, d string) []string {
|
||||
|
|
|
@ -531,3 +531,44 @@ func TestDifference(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrUtil_EqualStringMaps(t *testing.T) {
|
||||
m1 := map[string]string{
|
||||
"foo": "a",
|
||||
}
|
||||
m2 := map[string]string{
|
||||
"foo": "a",
|
||||
"bar": "b",
|
||||
}
|
||||
var m3 map[string]string
|
||||
|
||||
m4 := map[string]string{
|
||||
"dog": "",
|
||||
}
|
||||
|
||||
m5 := map[string]string{
|
||||
"cat": "",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
a map[string]string
|
||||
b map[string]string
|
||||
result bool
|
||||
}{
|
||||
{m1, m1, true},
|
||||
{m2, m2, true},
|
||||
{m1, m2, false},
|
||||
{m2, m1, false},
|
||||
{m2, m2, true},
|
||||
{m3, m1, false},
|
||||
{m3, m3, true},
|
||||
{m4, m5, false},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
actual := EqualStringMaps(test.a, test.b)
|
||||
if actual != test.result {
|
||||
t.Fatalf("case %d, expected %v, got %v", i, test.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/hashicorp/vault/helper/identity"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/helper/storagepacker"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -406,6 +407,7 @@ func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor,
|
|||
func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.Alias) (*identity.Entity, error) {
|
||||
var entity *identity.Entity
|
||||
var err error
|
||||
var update bool
|
||||
|
||||
if alias == nil {
|
||||
return nil, fmt.Errorf("alias is nil")
|
||||
|
@ -428,12 +430,12 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
|
|||
return nil, fmt.Errorf("mount accessor %q is not a mount of type %q", alias.MountAccessor, alias.MountType)
|
||||
}
|
||||
|
||||
// Check if an entity already exists for the given alais
|
||||
// Check if an entity already exists for the given alias
|
||||
entity, err = i.entityByAliasFactors(alias.MountAccessor, alias.Name, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entity != nil {
|
||||
if entity != nil && changedAliasIndex(entity, alias) == -1 {
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
|
@ -445,40 +447,50 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
|
|||
defer txn.Abort()
|
||||
|
||||
// Check if an entity was created before acquiring the lock
|
||||
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
|
||||
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entity != nil {
|
||||
return entity, nil
|
||||
idx := changedAliasIndex(entity, alias)
|
||||
if idx == -1 {
|
||||
return entity, nil
|
||||
}
|
||||
a := entity.Aliases[idx]
|
||||
a.Metadata = alias.Metadata
|
||||
a.LastUpdateTime = ptypes.TimestampNow()
|
||||
|
||||
update = true
|
||||
}
|
||||
|
||||
entity = new(identity.Entity)
|
||||
err = i.sanitizeEntity(ctx, entity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !update {
|
||||
entity = new(identity.Entity)
|
||||
err = i.sanitizeEntity(ctx, entity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a new alias
|
||||
newAlias := &identity.Alias{
|
||||
CanonicalID: entity.ID,
|
||||
Name: alias.Name,
|
||||
MountAccessor: alias.MountAccessor,
|
||||
Metadata: alias.Metadata,
|
||||
MountPath: mountValidationResp.MountPath,
|
||||
MountType: mountValidationResp.MountType,
|
||||
}
|
||||
// Create a new alias
|
||||
newAlias := &identity.Alias{
|
||||
CanonicalID: entity.ID,
|
||||
Name: alias.Name,
|
||||
MountAccessor: alias.MountAccessor,
|
||||
Metadata: alias.Metadata,
|
||||
MountPath: mountValidationResp.MountPath,
|
||||
MountType: mountValidationResp.MountType,
|
||||
}
|
||||
|
||||
err = i.sanitizeAlias(ctx, newAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = i.sanitizeAlias(ctx, newAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.logger.Debug("creating a new entity", "alias", newAlias)
|
||||
i.logger.Debug("creating a new entity", "alias", newAlias)
|
||||
|
||||
// Append the new alias to the new entity
|
||||
entity.Aliases = []*identity.Alias{
|
||||
newAlias,
|
||||
// Append the new alias to the new entity
|
||||
entity.Aliases = []*identity.Alias{
|
||||
newAlias,
|
||||
}
|
||||
}
|
||||
|
||||
// Update MemDB and persist entity object
|
||||
|
@ -491,3 +503,17 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
|
|||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// changedAliasIndex searches an entity for changed alias metadata.
|
||||
//
|
||||
// If a match is found, the changed alias's index is returned. If no alias
|
||||
// names match or no metadata is different, -1 is returned.
|
||||
func changedAliasIndex(entity *identity.Entity, alias *logical.Alias) int {
|
||||
for i, a := range entity.Aliases {
|
||||
if a.Name == alias.Name && !strutil.EqualStringMaps(a.Metadata, alias.Metadata) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
|
|
@ -96,6 +96,9 @@ func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
|
|||
MountType: "github",
|
||||
MountAccessor: ghAccessor,
|
||||
Name: "githubuser",
|
||||
Metadata: map[string]string{
|
||||
"foo": "a",
|
||||
},
|
||||
}
|
||||
|
||||
entity, err := is.CreateOrFetchEntity(ctx, alias)
|
||||
|
@ -129,6 +132,71 @@ func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
|
|||
if entity.Aliases[0].Name != alias.Name {
|
||||
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
|
||||
}
|
||||
if diff := deep.Equal(entity.Aliases[0].Metadata, map[string]string{"foo": "a"}); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
// Add a new alias to the entity and verify its existence
|
||||
registerReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "entity-alias",
|
||||
Data: map[string]interface{}{
|
||||
"name": "githubuser2",
|
||||
"canonical_id": entity.ID,
|
||||
"mount_accessor": ghAccessor,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := is.HandleRequest(ctx, registerReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
entity, err = is.CreateOrFetchEntity(ctx, alias)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entity == nil {
|
||||
t.Fatalf("expected a non-nil entity")
|
||||
}
|
||||
|
||||
if len(entity.Aliases) != 2 {
|
||||
t.Fatalf("bad: length of aliases; expected: 2, actual: %d", len(entity.Aliases))
|
||||
}
|
||||
|
||||
if entity.Aliases[1].Name != "githubuser2" {
|
||||
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, "githubuser2")
|
||||
}
|
||||
|
||||
if diff := deep.Equal(entity.Aliases[1].Metadata, map[string]string(nil)); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
// Change the metadata of an existing alias and verify that
|
||||
// a the change takes effect only for the target alias.
|
||||
alias.Metadata = map[string]string{
|
||||
"foo": "zzzz",
|
||||
}
|
||||
|
||||
entity, err = is.CreateOrFetchEntity(ctx, alias)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entity == nil {
|
||||
t.Fatalf("expected a non-nil entity")
|
||||
}
|
||||
|
||||
if len(entity.Aliases) != 2 {
|
||||
t.Fatalf("bad: length of aliases; expected: 2, actual: %d", len(entity.Aliases))
|
||||
}
|
||||
|
||||
if diff := deep.Equal(entity.Aliases[0].Metadata, map[string]string{"foo": "zzzz"}); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(entity.Aliases[1].Metadata, map[string]string(nil)); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityStore_EntityByAliasFactors(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue