diff --git a/changelog/15055.txt b/changelog/15055.txt new file mode 100644 index 000000000..648aa2998 --- /dev/null +++ b/changelog/15055.txt @@ -0,0 +1,3 @@ +```release-note:bug +identity: deduplicate policies when creating/updating identity groups +``` \ No newline at end of file diff --git a/vault/identity_store_groups_test.go b/vault/identity_store_groups_test.go index e3f1486f7..8e1d38038 100644 --- a/vault/identity_store_groups_test.go +++ b/vault/identity_store_groups_test.go @@ -774,6 +774,94 @@ func TestIdentityStore_GroupsCreateUpdate(t *testing.T) { } } +func TestIdentityStore_GroupsCreateUpdateDuplicatePolicy(t *testing.T) { + var resp *logical.Response + var err error + + ctx := namespace.RootContext(nil) + is, _, _ := testIdentityStoreWithGithubAuth(ctx, t) + + // Create a group with the above created 2 entities as its members + groupData := map[string]interface{}{ + "policies": []string{"testpolicy1", "testpolicy2"}, + "metadata": []string{"testkey1=testvalue1", "testkey2=testvalue2"}, + } + + // Create a group and get its ID + groupReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "group", + Data: groupData, + } + + // Create a group with the above 2 groups as its members + resp, err = is.HandleRequest(ctx, groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + groupID := resp.Data["id"].(string) + + // Read the group using its iD and check if all the fields are properly + // set + groupReq = &logical.Request{ + Operation: logical.ReadOperation, + Path: "group/id/" + groupID, + } + resp, err = is.HandleRequest(ctx, groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + expectedData := map[string]interface{}{ + "policies": []string{"testpolicy1", "testpolicy2"}, + "metadata": map[string]string{ + "testkey1": "testvalue1", + "testkey2": "testvalue2", + }, + "parent_group_ids": []string(nil), + } + expectedData["id"] = resp.Data["id"] + expectedData["type"] = resp.Data["type"] + expectedData["name"] = resp.Data["name"] + expectedData["creation_time"] = resp.Data["creation_time"] + expectedData["last_update_time"] = resp.Data["last_update_time"] + expectedData["modify_index"] = resp.Data["modify_index"] + expectedData["alias"] = resp.Data["alias"] + expectedData["namespace_id"] = "root" + expectedData["member_group_ids"] = resp.Data["member_group_ids"] + expectedData["member_entity_ids"] = resp.Data["member_entity_ids"] + + if diff := deep.Equal(expectedData, resp.Data); diff != nil { + t.Fatal(diff) + } + + // Update the policies and metadata in the group + groupReq.Operation = logical.UpdateOperation + groupReq.Data = groupData + + // Update by setting ID in the param + groupData["id"] = groupID + groupData["policies"] = []string{"updatedpolicy1", "updatedpolicy2", "updatedpolicy2"} + resp, err = is.HandleRequest(ctx, groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + // Check if updates are reflected + groupReq.Operation = logical.ReadOperation + resp, err = is.HandleRequest(ctx, groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + expectedData["policies"] = []string{"updatedpolicy1", "updatedpolicy2"} + expectedData["last_update_time"] = resp.Data["last_update_time"] + expectedData["modify_index"] = resp.Data["modify_index"] + if !reflect.DeepEqual(expectedData, resp.Data) { + t.Fatalf("bad: group data; expected: %#v\n actual: %#v\n", expectedData, resp.Data) + } +} + func TestIdentityStore_GroupsCRUD_ByID(t *testing.T) { var resp *logical.Response var err error diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 5f9e891e7..b6baea019 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -1483,6 +1483,9 @@ func (i *IdentityStore) sanitizeAndUpsertGroup(ctx context.Context, group *ident } } + // Remove duplicate policies + group.Policies = strutil.RemoveDuplicates(group.Policies, false) + txn := i.db.Txn(true) defer txn.Abort()