Return group memberships of entity during read (#3526)
* return group memberships of entity during read * Add implied group memberships to read response of entity * distinguish between all, direct and inherited group IDs of an entity * address review feedback * address review feedback * s/implied/inherited in tests
This commit is contained in:
parent
d7305a4681
commit
2af5b9274f
|
@ -446,6 +446,26 @@ func (i *IdentityStore) handleEntityReadCommon(entity *identity.Entity) (*logica
|
|||
// formats
|
||||
respData["aliases"] = aliasesToReturn
|
||||
|
||||
// Fetch the groups this entity belongs to and return their identifiers
|
||||
groups, inheritedGroups, err := i.groupsByEntityID(entity.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]string, len(groups))
|
||||
for i, group := range groups {
|
||||
groupIDs[i] = group.ID
|
||||
}
|
||||
respData["direct_group_ids"] = groupIDs
|
||||
|
||||
inheritedGroupIDs := make([]string, len(inheritedGroups))
|
||||
for i, group := range inheritedGroups {
|
||||
inheritedGroupIDs[i] = group.ID
|
||||
}
|
||||
respData["inherited_group_ids"] = inheritedGroupIDs
|
||||
|
||||
respData["group_ids"] = append(groupIDs, inheritedGroupIDs...)
|
||||
|
||||
return &logical.Response{
|
||||
Data: respData,
|
||||
}, nil
|
||||
|
|
|
@ -12,6 +12,86 @@ import (
|
|||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func TestIdentityStore_EntityReadGroupIDs(t *testing.T) {
|
||||
var err error
|
||||
var resp *logical.Response
|
||||
|
||||
i, _, _ := testIdentityStoreWithGithubAuth(t)
|
||||
|
||||
entityReq := &logical.Request{
|
||||
Path: "entity",
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
|
||||
resp, err = i.HandleRequest(entityReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
entityID := resp.Data["id"].(string)
|
||||
|
||||
groupReq := &logical.Request{
|
||||
Path: "group",
|
||||
Operation: logical.UpdateOperation,
|
||||
Data: map[string]interface{}{
|
||||
"member_entity_ids": []string{
|
||||
entityID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = i.HandleRequest(groupReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
groupID := resp.Data["id"].(string)
|
||||
|
||||
// Create another group with the above created group as its subgroup
|
||||
|
||||
groupReq.Data = map[string]interface{}{
|
||||
"member_group_ids": []string{groupID},
|
||||
}
|
||||
resp, err = i.HandleRequest(groupReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
inheritedGroupID := resp.Data["id"].(string)
|
||||
|
||||
lookupReq := &logical.Request{
|
||||
Path: "lookup/entity",
|
||||
Operation: logical.UpdateOperation,
|
||||
Data: map[string]interface{}{
|
||||
"type": "id",
|
||||
"id": entityID,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = i.HandleRequest(lookupReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
expected := []string{groupID, inheritedGroupID}
|
||||
actual := resp.Data["group_ids"].([]string)
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("bad: group_ids; expected: %#v\nactual: %#v\n", expected, actual)
|
||||
}
|
||||
|
||||
expected = []string{groupID}
|
||||
actual = resp.Data["direct_group_ids"].([]string)
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("bad: direct_group_ids; expected: %#v\nactual: %#v\n", expected, actual)
|
||||
}
|
||||
|
||||
expected = []string{inheritedGroupID}
|
||||
actual = resp.Data["inherited_group_ids"].([]string)
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("bad: inherited_group_ids; expected: %#v\nactual: %#v\n", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityStore_EntityCreateUpdate(t *testing.T) {
|
||||
var err error
|
||||
var resp *logical.Response
|
||||
|
|
|
@ -546,11 +546,11 @@ func TestIdentityStore_GroupMultiCase(t *testing.T) {
|
|||
|
||||
/*
|
||||
Test groups hierarchy:
|
||||
eng
|
||||
| |
|
||||
vault ops
|
||||
| | | |
|
||||
kube identity build deploy
|
||||
------- eng(entityID3) -------
|
||||
| |
|
||||
----- vault ----- -- ops(entityID2) --
|
||||
| | | |
|
||||
kube(entityID1) identity build deploy
|
||||
*/
|
||||
func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
|
@ -808,27 +808,36 @@ func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
|
|||
t.Fatalf("bad: policies; expected: 'engpolicy'\nactual:%#v", policies)
|
||||
}
|
||||
|
||||
groups, err := is.transitiveGroupsByEntityID(entityID1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(groups) != 3 {
|
||||
t.Fatalf("bad: length of groups; expected: 3, actual: %d", len(groups))
|
||||
}
|
||||
|
||||
groups, err = is.transitiveGroupsByEntityID(entityID2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(groups) != 2 {
|
||||
t.Fatalf("bad: length of groups; expected: 2, actual: %d", len(groups))
|
||||
}
|
||||
|
||||
groups, err = is.transitiveGroupsByEntityID(entityID3)
|
||||
groups, inheritedGroups, err := is.groupsByEntityID(entityID1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(groups) != 1 {
|
||||
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
|
||||
}
|
||||
if len(inheritedGroups) != 2 {
|
||||
t.Fatalf("bad: length of inheritedGroups; expected: 2, actual: %d", len(inheritedGroups))
|
||||
}
|
||||
|
||||
groups, inheritedGroups, err = is.groupsByEntityID(entityID2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(groups) != 1 {
|
||||
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
|
||||
}
|
||||
if len(inheritedGroups) != 1 {
|
||||
t.Fatalf("bad: length of inheritedGroups; expected: 1, actual: %d", len(inheritedGroups))
|
||||
}
|
||||
|
||||
groups, inheritedGroups, err = is.groupsByEntityID(entityID3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(groups) != 1 {
|
||||
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
|
||||
}
|
||||
if len(inheritedGroups) != 0 {
|
||||
t.Fatalf("bad: length of inheritedGroups; expected: 0, actual: %d", len(inheritedGroups))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1970,14 +1970,14 @@ func (i *IdentityStore) groupPoliciesByEntityID(entityID string) ([]string, erro
|
|||
return strutil.RemoveDuplicates(policies, false), nil
|
||||
}
|
||||
|
||||
func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity.Group, error) {
|
||||
func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) {
|
||||
if entityID == "" {
|
||||
return nil, fmt.Errorf("empty entity ID")
|
||||
return nil, nil, fmt.Errorf("empty entity ID")
|
||||
}
|
||||
|
||||
groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false)
|
||||
groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
visited := make(map[string]bool)
|
||||
|
@ -1985,7 +1985,7 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity
|
|||
for _, group := range groups {
|
||||
gGroups, err := i.collectGroupsReverseDFS(group, visited, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
tGroups = append(tGroups, gGroups...)
|
||||
}
|
||||
|
@ -2001,7 +2001,15 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity
|
|||
tGroups = append(tGroups, group)
|
||||
}
|
||||
|
||||
return tGroups, nil
|
||||
diff := diffGroups(groups, tGroups)
|
||||
|
||||
// For sanity
|
||||
// There should not be any group that gets deleted
|
||||
if len(diff.Deleted) != 0 {
|
||||
return nil, nil, fmt.Errorf("failed to diff group memberships")
|
||||
}
|
||||
|
||||
return diff.Unmodified, diff.New, nil
|
||||
}
|
||||
|
||||
func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) {
|
||||
|
|
Loading…
Reference in New Issue