VAULT-8719 Support data array for alias clash error response so UI/machines can understand error (#17459)

* VAULT-8719 Support data array for alias clash error response so UI can understand error

* VAULT-8719 Changelog

* VAULT-8719 Update alias mount update logic

* VAULT-8719 Further restrict IsError()
This commit is contained in:
Violet Hynes 2022-10-17 14:46:25 -04:00 committed by GitHub
parent cf6e6ae87d
commit 5861c51e70
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 296 additions and 34 deletions

3
changelog/17459.txt Normal file
View file

@ -0,0 +1,3 @@
```release-note:improvement
core/identity: Add machine-readable output to body of response upon alias clash during entity merge
```

View file

@ -1178,6 +1178,10 @@ func respondError(w http.ResponseWriter, status int, err error) {
logical.RespondError(w, status, err) logical.RespondError(w, status, err)
} }
func respondErrorAndData(w http.ResponseWriter, status int, data interface{}, err error) {
logical.RespondErrorAndData(w, status, data, err)
}
func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool { func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool {
statusCode, newErr := logical.RespondErrorCommon(req, resp, err) statusCode, newErr := logical.RespondErrorCommon(req, resp, err)
if newErr == nil && statusCode == 0 { if newErr == nil && statusCode == 0 {
@ -1193,6 +1197,12 @@ func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logic
return true return true
} }
if resp != nil {
if data := resp.Data["data"]; data != nil {
respondErrorAndData(w, statusCode, data, newErr)
return true
}
}
respondError(w, statusCode, newErr) respondError(w, statusCode, newErr)
return true return true
} }

View file

@ -92,7 +92,8 @@ func (r *Response) AddWarning(warning string) {
// IsError returns true if this response seems to indicate an error. // IsError returns true if this response seems to indicate an error.
func (r *Response) IsError() bool { func (r *Response) IsError() bool {
return r != nil && r.Data != nil && len(r.Data) == 1 && r.Data["error"] != nil // If the response data contains only an 'error' element, or an 'error' and a 'data' element only
return r != nil && r.Data != nil && r.Data["error"] != nil && (len(r.Data) == 1 || (r.Data["data"] != nil && len(r.Data) == 2))
} }
func (r *Response) Error() error { func (r *Response) Error() error {

View file

@ -182,3 +182,23 @@ func RespondError(w http.ResponseWriter, status int, err error) {
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
enc.Encode(resp) enc.Encode(resp)
} }
func RespondErrorAndData(w http.ResponseWriter, status int, data interface{}, err error) {
AdjustErrorStatusCode(&status, err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
type ErrorAndDataResponse struct {
Errors []string `json:"errors"`
Data interface{} `json:"data""`
}
resp := &ErrorAndDataResponse{Errors: make([]string, 0, 1)}
if err != nil {
resp.Errors = append(resp.Errors, err.Error())
}
resp.Data = data
enc := json.NewEncoder(w)
enc.Encode(resp)
}

View file

@ -2,7 +2,9 @@ package identity
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"strings" "strings"
"testing" "testing"
@ -524,6 +526,9 @@ func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) {
if err == nil { if err == nil {
t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp) t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp)
} }
if mergeResp != nil {
t.Fatalf("Response was non-nil. Resp:%#v", mergeResp)
}
if !strings.Contains(err.Error(), "toEntity and at least one fromEntity have aliases with the same mount accessor") { if !strings.Contains(err.Error(), "toEntity and at least one fromEntity have aliases with the same mount accessor") {
t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %v", err) t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %v", err)
} }
@ -553,6 +558,170 @@ func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) {
} }
} }
func TestIdentityStore_MergeEntities_FailsDueToClashInFromEntities_CheckRawRequest(t *testing.T) {
coreConfig := &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{
"password": "training",
})
if err != nil {
t.Fatal(err)
}
mounts, err := client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
var mountAccessor string
for k, v := range mounts {
if k == "userpass/" {
mountAccessor = v.Accessor
break
}
}
if mountAccessor == "" {
t.Fatal("did not find userpass accessor")
}
_, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t)
// Create userpass login for alice
_, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{
"password": "training",
})
if err != nil {
t.Fatal(err)
}
_, entityIdAlice, _ := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t)
// Perform entity merge as a Raw Request so we can investigate the response body
req := client.NewRequest("POST", "/v1/identity/entity/merge")
req.SetJSONBody(map[string]interface{}{
"to_entity_id": entityIdBob,
"from_entity_ids": []string{entityIdAlice},
})
resp, err := client.RawRequest(req)
if err == nil {
t.Fatalf("Expected error but did not get one. Response: %v", resp)
}
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
bodyString := string(bodyBytes)
if resp.StatusCode != 400 {
t.Fatal("Incorrect status code for response")
}
var mapOutput map[string]interface{}
if err = json.Unmarshal([]byte(bodyString), &mapOutput); err != nil {
t.Fatal(err)
}
errorStrings, ok := mapOutput["errors"].([]interface{})
if !ok {
t.Fatalf("error not present in response - full response: %s", bodyString)
}
if len(errorStrings) != 1 {
t.Fatalf("Incorrect number of errors in response - full response: %s", bodyString)
}
errorString, ok := errorStrings[0].(string)
if !ok {
t.Fatalf("error not present in response - full response: %s", bodyString)
}
if !strings.Contains(errorString, "toEntity and at least one fromEntity have aliases with the same mount accessor") {
t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %s", errorString)
}
dataArray, ok := mapOutput["data"].([]interface{})
if !ok {
t.Fatalf("data not present in response - full response: %s", bodyString)
}
if len(dataArray) != 2 {
t.Fatalf("Incorrect amount of clash data in response - full response: %s", bodyString)
}
for _, data := range dataArray {
dataMap, ok := data.(map[string]interface{})
if !ok {
t.Fatalf("data could not be understood - full response: %s", bodyString)
}
entityId, ok := dataMap["entity_id"].(string)
if !ok {
t.Fatalf("entity_id not present in data - full response: %s", bodyString)
}
if entityId != entityIdBob && entityId != entityIdAlice {
t.Fatalf("entityId not bob or alice - full response: %s", bodyString)
}
entity, ok := dataMap["entity"].(string)
if !ok {
t.Fatalf("entity not present in data - full response: %s", bodyString)
}
if entity != "bob-smith" && entity != "alice-smith" {
t.Fatalf("entity not bob or alice - full response: %s", bodyString)
}
alias, ok := dataMap["alias"].(string)
if !ok {
t.Fatalf("alias not present in data - full response: %s", bodyString)
}
if alias != "bob" && alias != "alice" {
t.Fatalf("alias not bob or alice - full response: %s", bodyString)
}
mountPath, ok := dataMap["mount_path"].(string)
if !ok {
t.Fatalf("mountPath not present in data - full response: %s", bodyString)
}
if mountPath != "auth/userpass/" {
t.Fatalf("mountPath not auth/userpass/ - full response: %s", bodyString)
}
mount, ok := dataMap["mount"].(string)
if !ok {
t.Fatalf("mount not present in data - full response: %s", bodyString)
}
if mount != "userpass" {
t.Fatalf("mount not userpass - full response: %s", bodyString)
}
}
}
func TestIdentityStore_MergeEntities_SameMountAccessor_ThenUseAlias(t *testing.T) { func TestIdentityStore_MergeEntities_SameMountAccessor_ThenUseAlias(t *testing.T) {
coreConfig := &vault.CoreConfig{ coreConfig := &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{ CredentialBackends: map[string]logical.Factory{

View file

@ -190,10 +190,22 @@ func (i *IdentityStore) pathEntityMergeID() framework.OperationFunc {
return nil, err return nil, err
} }
userErr, intErr := i.mergeEntity(ctx, txn, toEntity, fromEntityIDs, conflictingAliasIDsToKeep, force, false, false, true, false) userErr, intErr, aliases := i.mergeEntity(ctx, txn, toEntity, fromEntityIDs, conflictingAliasIDsToKeep, force, false, false, true, false)
if userErr != nil { if userErr != nil {
// Not an error due to alias clash, return like normal
if len(aliases) == 0 {
return logical.ErrorResponse(userErr.Error()), nil return logical.ErrorResponse(userErr.Error()), nil
} }
// Alias clash error, so include additional details
resp := &logical.Response{
Data: map[string]interface{}{
"error": userErr.Error(),
"data": aliases,
},
}
return resp, nil
}
if intErr != nil { if intErr != nil {
return nil, intErr return nil, intErr
} }
@ -734,29 +746,40 @@ func (i *IdentityStore) handlePathEntityListCommon(ctx context.Context, req *log
} }
func (i *IdentityStore) mergeEntityAsPartOfUpsert(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityID string, persist bool) (error, error) { func (i *IdentityStore) mergeEntityAsPartOfUpsert(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityID string, persist bool) (error, error) {
return i.mergeEntity(ctx, txn, toEntity, []string{fromEntityID}, []string{}, true, false, true, persist, true) err1, err2, _ := i.mergeEntity(ctx, txn, toEntity, []string{fromEntityID}, []string{}, true, false, true, persist, true)
return err1, err2
} }
func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityIDs, conflictingAliasIDsToKeep []string, force, grabLock, mergePolicies, persist, forceMergeAliases bool) (error, error) { // A small type to return useful information to the UI after an entity clash
// Every alias involved in a clash will be returned.
type aliasClashInformation struct {
Alias string `json:"alias"`
Entity string `json:"entity"`
EntityId string `json:"entity_id"`
Mount string `json:"mount"`
MountPath string `json:"mount_path"`
}
func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityIDs, conflictingAliasIDsToKeep []string, force, grabLock, mergePolicies, persist, forceMergeAliases bool) (error, error, []aliasClashInformation) {
if grabLock { if grabLock {
i.lock.Lock() i.lock.Lock()
defer i.lock.Unlock() defer i.lock.Unlock()
} }
if toEntity == nil { if toEntity == nil {
return errors.New("entity id to merge to is invalid"), nil return errors.New("entity id to merge to is invalid"), nil, nil
} }
ns, err := namespace.FromContext(ctx) ns, err := namespace.FromContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
if toEntity.NamespaceID != ns.ID { if toEntity.NamespaceID != ns.ID {
return errors.New("entity id to merge into does not belong to the request's namespace"), nil return errors.New("entity id to merge into does not belong to the request's namespace"), nil, nil
} }
if len(fromEntityIDs) > 1 && len(conflictingAliasIDsToKeep) > 1 { if len(fromEntityIDs) > 1 && len(conflictingAliasIDsToKeep) > 1 {
return errors.New("aliases conflicts cannot be resolved with multiple from entity ids - merge one entity at a time"), nil return errors.New("aliases conflicts cannot be resolved with multiple from entity ids - merge one entity at a time"), nil, nil
} }
sanitizedFromEntityIDs := strutil.RemoveDuplicates(fromEntityIDs, false) sanitizedFromEntityIDs := strutil.RemoveDuplicates(fromEntityIDs, false)
@ -764,27 +787,35 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// A map to check if there are any clashes between mount accessors for any of the sanitizedFromEntityIDs // A map to check if there are any clashes between mount accessors for any of the sanitizedFromEntityIDs
fromEntityAccessors := make(map[string]string) fromEntityAccessors := make(map[string]string)
// A list detailing all aliases where a clash has occurred, so that the error
// can be understood by the UI
aliasesInvolvedInClashes := make([]aliasClashInformation, 0)
i.UpdateEntityWithMountInformation(toEntity)
// An error detailing if any alias clashes happen (shared mount accessor) // An error detailing if any alias clashes happen (shared mount accessor)
var aliasClashError error var aliasClashError error
for _, fromEntityID := range sanitizedFromEntityIDs { for _, fromEntityID := range sanitizedFromEntityIDs {
if fromEntityID == toEntity.ID { if fromEntityID == toEntity.ID {
return errors.New("to_entity_id should not be present in from_entity_ids"), nil return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
} }
fromEntity, err := i.MemDBEntityByID(fromEntityID, false) fromEntity, err := i.MemDBEntityByID(fromEntityID, false)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
if fromEntity == nil { if fromEntity == nil {
return errors.New("entity id to merge from is invalid"), nil return errors.New("entity id to merge from is invalid"), nil, nil
} }
if fromEntity.NamespaceID != toEntity.NamespaceID { if fromEntity.NamespaceID != toEntity.NamespaceID {
return errors.New("entity id to merge from does not belong to this namespace"), nil return errors.New("entity id to merge from does not belong to this namespace"), nil, nil
} }
i.UpdateEntityWithMountInformation(fromEntity)
// If we're not resolving a conflict, we check to see if // If we're not resolving a conflict, we check to see if
// any aliases conflict between the toEntity and this fromEntity: // any aliases conflict between the toEntity and this fromEntity:
if !forceMergeAliases && len(conflictingAliasIDsToKeep) == 0 { if !forceMergeAliases && len(conflictingAliasIDsToKeep) == 0 {
@ -793,7 +824,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// First, check to see if this alias clashes with an alias from any of the other fromEntities: // First, check to see if this alias clashes with an alias from any of the other fromEntities:
id, mountAccessorInAnotherFromEntity := fromEntityAccessors[fromAlias.MountAccessor] id, mountAccessorInAnotherFromEntity := fromEntityAccessors[fromAlias.MountAccessor]
if mountAccessorInAnotherFromEntity && (id != fromEntityID) { if mountAccessorInAnotherFromEntity && (id != fromEntityID) {
return fmt.Errorf("mount accessor %s found in multiple fromEntities, merge should be done with one fromEntity at a time", fromAlias.MountAccessor), nil return fmt.Errorf("mount accessor %s found in multiple fromEntities, merge should be done with one fromEntity at a time", fromAlias.MountAccessor), nil, nil
} }
fromEntityAccessors[fromAlias.MountAccessor] = fromEntityID fromEntityAccessors[fromAlias.MountAccessor] = fromEntityID
@ -806,6 +837,22 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
aliasClashError = multierror.Append(aliasClashError, aliasClashError = multierror.Append(aliasClashError,
fmt.Errorf("mountAccessor: %s, toEntity ID: %s, fromEntity ID: %s, conflicting toEntity alias ID: %s, conflicting fromEntity alias ID: %s", fmt.Errorf("mountAccessor: %s, toEntity ID: %s, fromEntity ID: %s, conflicting toEntity alias ID: %s, conflicting fromEntity alias ID: %s",
toAlias.MountAccessor, toEntity.ID, fromEntityID, toAlias.ID, fromAlias.ID)) toAlias.MountAccessor, toEntity.ID, fromEntityID, toAlias.ID, fromAlias.ID))
// Also add both to our summary of all clashes:
aliasesInvolvedInClashes = append(aliasesInvolvedInClashes, aliasClashInformation{
Entity: toEntity.Name,
EntityId: toEntity.ID,
Alias: toAlias.Name,
Mount: toAlias.MountType,
MountPath: toAlias.MountPath,
})
aliasesInvolvedInClashes = append(aliasesInvolvedInClashes, aliasClashInformation{
Entity: fromEntity.Name,
EntityId: fromEntityID,
Alias: fromAlias.Name,
Mount: fromAlias.MountType,
MountPath: fromAlias.MountPath,
})
} }
} }
} }
@ -814,7 +861,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
for configID, configSecret := range fromEntity.MFASecrets { for configID, configSecret := range fromEntity.MFASecrets {
_, ok := toEntity.MFASecrets[configID] _, ok := toEntity.MFASecrets[configID]
if ok && !force { if ok && !force {
return nil, fmt.Errorf("conflicting MFA config ID %q in entity ID %q", configID, fromEntity.ID) return nil, fmt.Errorf("conflicting MFA config ID %q in entity ID %q", configID, fromEntity.ID), nil
} else { } else {
if toEntity.MFASecrets == nil { if toEntity.MFASecrets == nil {
toEntity.MFASecrets = make(map[string]*mfa.Secret) toEntity.MFASecrets = make(map[string]*mfa.Secret)
@ -826,7 +873,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// Check alias clashes after validating every fromEntity, so that we have a full list of errors // Check alias clashes after validating every fromEntity, so that we have a full list of errors
if aliasClashError != nil { if aliasClashError != nil {
return aliasClashError, nil return aliasClashError, nil, aliasesInvolvedInClashes
} }
isPerfSecondaryOrStandby := i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || isPerfSecondaryOrStandby := i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) ||
@ -849,20 +896,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
for _, fromEntityID := range sanitizedFromEntityIDs { for _, fromEntityID := range sanitizedFromEntityIDs {
if fromEntityID == toEntity.ID { if fromEntityID == toEntity.ID {
return errors.New("to_entity_id should not be present in from_entity_ids"), nil return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
} }
fromEntity, err := i.MemDBEntityByID(fromEntityID, true) fromEntity, err := i.MemDBEntityByID(fromEntityID, true)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
if fromEntity == nil { if fromEntity == nil {
return errors.New("entity id to merge from is invalid"), nil return errors.New("entity id to merge from is invalid"), nil, nil
} }
if fromEntity.NamespaceID != toEntity.NamespaceID { if fromEntity.NamespaceID != toEntity.NamespaceID {
return errors.New("entity id to merge from does not belong to this namespace"), nil return errors.New("entity id to merge from does not belong to this namespace"), nil, nil
} }
for _, fromAlias := range fromEntity.Aliases { for _, fromAlias := range fromEntity.Aliases {
@ -877,13 +924,13 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId)
err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err), nil
} }
} else if strutil.StrListContains(conflictingAliasIDsToKeep, toAliasId) { } else if strutil.StrListContains(conflictingAliasIDsToKeep, toAliasId) {
i.logger.Info("Deleting from_entity alias during entity merge", "from_entity", fromEntityID, "deleted_alias", fromAlias.ID) i.logger.Info("Deleting from_entity alias during entity merge", "from_entity", fromEntityID, "deleted_alias", fromAlias.ID)
err := i.MemDBDeleteAliasByIDInTxn(txn, fromAlias.ID, false) err := i.MemDBDeleteAliasByIDInTxn(txn, fromAlias.ID, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err), nil
} }
// Continue to next alias, as there's no alias to merge left in the from_entity // Continue to next alias, as there's no alias to merge left in the from_entity
@ -892,10 +939,10 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId)
err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err), nil
} }
} else { } else {
return fmt.Errorf("conflicting mount accessors in following alias IDs and neither were present in conflicting_alias_ids_to_keep: %s, %s", fromAlias.ID, toAliasId), nil return fmt.Errorf("conflicting mount accessors in following alias IDs and neither were present in conflicting_alias_ids_to_keep: %s, %s", fromAlias.ID, toAliasId), nil, nil
} }
} }
} }
@ -907,7 +954,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false) err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update alias during merge: %w", err) return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil
} }
// Add the alias to the desired entity // Add the alias to the desired entity
@ -932,13 +979,13 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// internal and external // internal and external
groups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, fromEntity.ID, true, false) groups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, fromEntity.ID, true, false)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
for _, group := range groups { for _, group := range groups {
group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, fromEntity.ID) group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, fromEntity.ID)
err = i.UpsertGroupInTxn(ctx, txn, group, persist && !isPerfSecondaryOrStandby) err = i.UpsertGroupInTxn(ctx, txn, group, persist && !isPerfSecondaryOrStandby)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
fromEntityGroups = append(fromEntityGroups, group) fromEntityGroups = append(fromEntityGroups, group)
@ -947,14 +994,14 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// Delete the entity which we are merging from in MemDB using the same transaction // Delete the entity which we are merging from in MemDB using the same transaction
err = i.MemDBDeleteEntityByIDInTxn(txn, fromEntity.ID) err = i.MemDBDeleteEntityByIDInTxn(txn, fromEntity.ID)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
if persist && !isPerfSecondaryOrStandby { if persist && !isPerfSecondaryOrStandby {
// Delete the entity which we are merging from in storage // Delete the entity which we are merging from in storage
err = i.entityPacker.DeleteItem(ctx, fromEntity.ID) err = i.entityPacker.DeleteItem(ctx, fromEntity.ID)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
} }
} }
@ -962,14 +1009,14 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// Update MemDB with changes to the entity we are merging to // Update MemDB with changes to the entity we are merging to
err = i.MemDBUpsertEntityInTxn(txn, toEntity) err = i.MemDBUpsertEntityInTxn(txn, toEntity)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
for _, group := range fromEntityGroups { for _, group := range fromEntityGroups {
group.MemberEntityIDs = append(group.MemberEntityIDs, toEntity.ID) group.MemberEntityIDs = append(group.MemberEntityIDs, toEntity.ID)
err = i.UpsertGroupInTxn(ctx, txn, group, persist && !isPerfSecondaryOrStandby) err = i.UpsertGroupInTxn(ctx, txn, group, persist && !isPerfSecondaryOrStandby)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
} }
@ -977,7 +1024,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
// Persist the entity which we are merging to // Persist the entity which we are merging to
toEntityAsAny, err := ptypes.MarshalAny(toEntity) toEntityAsAny, err := ptypes.MarshalAny(toEntity)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
item := &storagepacker.Item{ item := &storagepacker.Item{
ID: toEntity.ID, ID: toEntity.ID,
@ -986,11 +1033,11 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
err = i.entityPacker.PutItem(ctx, item) err = i.entityPacker.PutItem(ctx, item)
if err != nil { if err != nil {
return nil, err return nil, err, nil
} }
} }
return nil, nil return nil, nil, nil
} }
var entityHelp = map[string][2]string{ var entityHelp = map[string][2]string{

View file

@ -1092,6 +1092,18 @@ func (i *IdentityStore) MemDBEntityByIDInTxn(txn *memdb.Txn, entityID string, cl
return entity, nil return entity, nil
} }
func (i *IdentityStore) UpdateEntityWithMountInformation(entity *identity.Entity) {
if entity != nil {
for _, alias := range entity.Aliases {
mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor)
if mountValidationResp != nil {
alias.MountType = mountValidationResp.MountType
alias.MountPath = mountValidationResp.MountPath
}
}
}
}
func (i *IdentityStore) MemDBEntityByID(entityID string, clone bool) (*identity.Entity, error) { func (i *IdentityStore) MemDBEntityByID(entityID string, clone bool) (*identity.Entity, error) {
if entityID == "" { if entityID == "" {
return nil, fmt.Errorf("missing entity id") return nil, fmt.Errorf("missing entity id")