Return group memberships of entity during read (#3526)

* return group memberships of entity during read

* Add implied group memberships to read response of entity

* distinguish between all, direct and inherited group IDs of an entity

* address review feedback

* address review feedback

* s/implied/inherited in tests
This commit is contained in:
Vishal Nayak 2017-11-06 13:01:48 -05:00 committed by GitHub
parent d7305a4681
commit 2af5b9274f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 28 deletions

View File

@ -446,6 +446,26 @@ func (i *IdentityStore) handleEntityReadCommon(entity *identity.Entity) (*logica
// formats
respData["aliases"] = aliasesToReturn
// Fetch the groups this entity belongs to and return their identifiers
groups, inheritedGroups, err := i.groupsByEntityID(entity.ID)
if err != nil {
return nil, err
}
groupIDs := make([]string, len(groups))
for i, group := range groups {
groupIDs[i] = group.ID
}
respData["direct_group_ids"] = groupIDs
inheritedGroupIDs := make([]string, len(inheritedGroups))
for i, group := range inheritedGroups {
inheritedGroupIDs[i] = group.ID
}
respData["inherited_group_ids"] = inheritedGroupIDs
respData["group_ids"] = append(groupIDs, inheritedGroupIDs...)
return &logical.Response{
Data: respData,
}, nil

View File

@ -12,6 +12,86 @@ import (
"github.com/hashicorp/vault/logical"
)
func TestIdentityStore_EntityReadGroupIDs(t *testing.T) {
var err error
var resp *logical.Response
i, _, _ := testIdentityStoreWithGithubAuth(t)
entityReq := &logical.Request{
Path: "entity",
Operation: logical.UpdateOperation,
}
resp, err = i.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
entityID := resp.Data["id"].(string)
groupReq := &logical.Request{
Path: "group",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"member_entity_ids": []string{
entityID,
},
},
}
resp, err = i.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
groupID := resp.Data["id"].(string)
// Create another group with the above created group as its subgroup
groupReq.Data = map[string]interface{}{
"member_group_ids": []string{groupID},
}
resp, err = i.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
inheritedGroupID := resp.Data["id"].(string)
lookupReq := &logical.Request{
Path: "lookup/entity",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"type": "id",
"id": entityID,
},
}
resp, err = i.HandleRequest(lookupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
expected := []string{groupID, inheritedGroupID}
actual := resp.Data["group_ids"].([]string)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: group_ids; expected: %#v\nactual: %#v\n", expected, actual)
}
expected = []string{groupID}
actual = resp.Data["direct_group_ids"].([]string)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: direct_group_ids; expected: %#v\nactual: %#v\n", expected, actual)
}
expected = []string{inheritedGroupID}
actual = resp.Data["inherited_group_ids"].([]string)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: inherited_group_ids; expected: %#v\nactual: %#v\n", expected, actual)
}
}
func TestIdentityStore_EntityCreateUpdate(t *testing.T) {
var err error
var resp *logical.Response

View File

@ -546,11 +546,11 @@ func TestIdentityStore_GroupMultiCase(t *testing.T) {
/*
Test groups hierarchy:
eng
| |
vault ops
| | | |
kube identity build deploy
------- eng(entityID3) -------
| |
----- vault ----- -- ops(entityID2) --
| | | |
kube(entityID1) identity build deploy
*/
func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
var resp *logical.Response
@ -808,27 +808,36 @@ func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
t.Fatalf("bad: policies; expected: 'engpolicy'\nactual:%#v", policies)
}
groups, err := is.transitiveGroupsByEntityID(entityID1)
if err != nil {
t.Fatal(err)
}
if len(groups) != 3 {
t.Fatalf("bad: length of groups; expected: 3, actual: %d", len(groups))
}
groups, err = is.transitiveGroupsByEntityID(entityID2)
if err != nil {
t.Fatal(err)
}
if len(groups) != 2 {
t.Fatalf("bad: length of groups; expected: 2, actual: %d", len(groups))
}
groups, err = is.transitiveGroupsByEntityID(entityID3)
groups, inheritedGroups, err := is.groupsByEntityID(entityID1)
if err != nil {
t.Fatal(err)
}
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
if len(inheritedGroups) != 2 {
t.Fatalf("bad: length of inheritedGroups; expected: 2, actual: %d", len(inheritedGroups))
}
groups, inheritedGroups, err = is.groupsByEntityID(entityID2)
if err != nil {
t.Fatal(err)
}
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
if len(inheritedGroups) != 1 {
t.Fatalf("bad: length of inheritedGroups; expected: 1, actual: %d", len(inheritedGroups))
}
groups, inheritedGroups, err = is.groupsByEntityID(entityID3)
if err != nil {
t.Fatal(err)
}
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
if len(inheritedGroups) != 0 {
t.Fatalf("bad: length of inheritedGroups; expected: 0, actual: %d", len(inheritedGroups))
}
}

View File

@ -1970,14 +1970,14 @@ func (i *IdentityStore) groupPoliciesByEntityID(entityID string) ([]string, erro
return strutil.RemoveDuplicates(policies, false), nil
}
func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity.Group, error) {
func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) {
if entityID == "" {
return nil, fmt.Errorf("empty entity ID")
return nil, nil, fmt.Errorf("empty entity ID")
}
groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false)
groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false)
if err != nil {
return nil, err
return nil, nil, err
}
visited := make(map[string]bool)
@ -1985,7 +1985,7 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity
for _, group := range groups {
gGroups, err := i.collectGroupsReverseDFS(group, visited, nil)
if err != nil {
return nil, err
return nil, nil, err
}
tGroups = append(tGroups, gGroups...)
}
@ -2001,7 +2001,15 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity
tGroups = append(tGroups, group)
}
return tGroups, nil
diff := diffGroups(groups, tGroups)
// For sanity
// There should not be any group that gets deleted
if len(diff.Deleted) != 0 {
return nil, nil, fmt.Errorf("failed to diff group memberships")
}
return diff.Unmodified, diff.New, nil
}
func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) {