From 5861c51e700a9fa9d1f54a3a03869f2e83ad1b1d Mon Sep 17 00:00:00 2001 From: Violet Hynes Date: Mon, 17 Oct 2022 14:46:25 -0400 Subject: [PATCH] VAULT-8719 Support data array for alias clash error response so UI/machines can understand error (#17459) * VAULT-8719 Support data array for alias clash error response so UI can understand error * VAULT-8719 Changelog * VAULT-8719 Update alias mount update logic * VAULT-8719 Further restrict IsError() --- changelog/17459.txt | 3 + http/handler.go | 10 ++ sdk/logical/response.go | 3 +- sdk/logical/response_util.go | 20 +++ vault/external_tests/identity/aliases_test.go | 169 ++++++++++++++++++ vault/identity_store_entities.go | 113 ++++++++---- vault/identity_store_util.go | 12 ++ 7 files changed, 296 insertions(+), 34 deletions(-) create mode 100644 changelog/17459.txt diff --git a/changelog/17459.txt b/changelog/17459.txt new file mode 100644 index 000000000..fd240c537 --- /dev/null +++ b/changelog/17459.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core/identity: Add machine-readable output to body of response upon alias clash during entity merge +``` \ No newline at end of file diff --git a/http/handler.go b/http/handler.go index 5bf848db2..6f5e29465 100644 --- a/http/handler.go +++ b/http/handler.go @@ -1178,6 +1178,10 @@ func respondError(w http.ResponseWriter, status int, err error) { logical.RespondError(w, status, err) } +func respondErrorAndData(w http.ResponseWriter, status int, data interface{}, err error) { + logical.RespondErrorAndData(w, status, data, err) +} + func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool { statusCode, newErr := logical.RespondErrorCommon(req, resp, err) if newErr == nil && statusCode == 0 { @@ -1193,6 +1197,12 @@ func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logic return true } + if resp != nil { + if data := resp.Data["data"]; data != nil { + respondErrorAndData(w, statusCode, data, newErr) + return true + } + } respondError(w, statusCode, newErr) return true } diff --git a/sdk/logical/response.go b/sdk/logical/response.go index 19194f524..0f8a2210e 100644 --- a/sdk/logical/response.go +++ b/sdk/logical/response.go @@ -92,7 +92,8 @@ func (r *Response) AddWarning(warning string) { // IsError returns true if this response seems to indicate an error. func (r *Response) IsError() bool { - return r != nil && r.Data != nil && len(r.Data) == 1 && r.Data["error"] != nil + // If the response data contains only an 'error' element, or an 'error' and a 'data' element only + return r != nil && r.Data != nil && r.Data["error"] != nil && (len(r.Data) == 1 || (r.Data["data"] != nil && len(r.Data) == 2)) } func (r *Response) Error() error { diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index a269fc639..4a9f61d56 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -182,3 +182,23 @@ func RespondError(w http.ResponseWriter, status int, err error) { enc := json.NewEncoder(w) enc.Encode(resp) } + +func RespondErrorAndData(w http.ResponseWriter, status int, data interface{}, err error) { + AdjustErrorStatusCode(&status, err) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + type ErrorAndDataResponse struct { + Errors []string `json:"errors"` + Data interface{} `json:"data""` + } + resp := &ErrorAndDataResponse{Errors: make([]string, 0, 1)} + if err != nil { + resp.Errors = append(resp.Errors, err.Error()) + } + resp.Data = data + + enc := json.NewEncoder(w) + enc.Encode(resp) +} diff --git a/vault/external_tests/identity/aliases_test.go b/vault/external_tests/identity/aliases_test.go index 9928ccdb9..d58540f8c 100644 --- a/vault/external_tests/identity/aliases_test.go +++ b/vault/external_tests/identity/aliases_test.go @@ -2,7 +2,9 @@ package identity import ( "context" + "encoding/json" "fmt" + "io" "strings" "testing" @@ -524,6 +526,9 @@ func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) { if err == nil { t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp) } + if mergeResp != nil { + t.Fatalf("Response was non-nil. Resp:%#v", mergeResp) + } if !strings.Contains(err.Error(), "toEntity and at least one fromEntity have aliases with the same mount accessor") { t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %v", err) } @@ -553,6 +558,170 @@ func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) { } } +func TestIdentityStore_MergeEntities_FailsDueToClashInFromEntities_CheckRawRequest(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + mounts, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + if mountAccessor == "" { + t.Fatal("did not find userpass accessor") + } + + _, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + + // Create userpass login for alice + _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + _, entityIdAlice, _ := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + + // Perform entity merge as a Raw Request so we can investigate the response body + req := client.NewRequest("POST", "/v1/identity/entity/merge") + req.SetJSONBody(map[string]interface{}{ + "to_entity_id": entityIdBob, + "from_entity_ids": []string{entityIdAlice}, + }) + + resp, err := client.RawRequest(req) + if err == nil { + t.Fatalf("Expected error but did not get one. Response: %v", resp) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + bodyString := string(bodyBytes) + + if resp.StatusCode != 400 { + t.Fatal("Incorrect status code for response") + } + + var mapOutput map[string]interface{} + if err = json.Unmarshal([]byte(bodyString), &mapOutput); err != nil { + t.Fatal(err) + } + + errorStrings, ok := mapOutput["errors"].([]interface{}) + if !ok { + t.Fatalf("error not present in response - full response: %s", bodyString) + } + + if len(errorStrings) != 1 { + t.Fatalf("Incorrect number of errors in response - full response: %s", bodyString) + } + + errorString, ok := errorStrings[0].(string) + if !ok { + t.Fatalf("error not present in response - full response: %s", bodyString) + } + + if !strings.Contains(errorString, "toEntity and at least one fromEntity have aliases with the same mount accessor") { + t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %s", errorString) + } + + dataArray, ok := mapOutput["data"].([]interface{}) + if !ok { + t.Fatalf("data not present in response - full response: %s", bodyString) + } + + if len(dataArray) != 2 { + t.Fatalf("Incorrect amount of clash data in response - full response: %s", bodyString) + } + + for _, data := range dataArray { + dataMap, ok := data.(map[string]interface{}) + if !ok { + t.Fatalf("data could not be understood - full response: %s", bodyString) + } + + entityId, ok := dataMap["entity_id"].(string) + if !ok { + t.Fatalf("entity_id not present in data - full response: %s", bodyString) + } + + if entityId != entityIdBob && entityId != entityIdAlice { + t.Fatalf("entityId not bob or alice - full response: %s", bodyString) + } + + entity, ok := dataMap["entity"].(string) + if !ok { + t.Fatalf("entity not present in data - full response: %s", bodyString) + } + + if entity != "bob-smith" && entity != "alice-smith" { + t.Fatalf("entity not bob or alice - full response: %s", bodyString) + } + + alias, ok := dataMap["alias"].(string) + if !ok { + t.Fatalf("alias not present in data - full response: %s", bodyString) + } + + if alias != "bob" && alias != "alice" { + t.Fatalf("alias not bob or alice - full response: %s", bodyString) + } + + mountPath, ok := dataMap["mount_path"].(string) + if !ok { + t.Fatalf("mountPath not present in data - full response: %s", bodyString) + } + + if mountPath != "auth/userpass/" { + t.Fatalf("mountPath not auth/userpass/ - full response: %s", bodyString) + } + + mount, ok := dataMap["mount"].(string) + if !ok { + t.Fatalf("mount not present in data - full response: %s", bodyString) + } + + if mount != "userpass" { + t.Fatalf("mount not userpass - full response: %s", bodyString) + } + } +} + func TestIdentityStore_MergeEntities_SameMountAccessor_ThenUseAlias(t *testing.T) { coreConfig := &vault.CoreConfig{ CredentialBackends: map[string]logical.Factory{ diff --git a/vault/identity_store_entities.go b/vault/identity_store_entities.go index 11db95d8b..e9e74ab0a 100644 --- a/vault/identity_store_entities.go +++ b/vault/identity_store_entities.go @@ -190,9 +190,21 @@ func (i *IdentityStore) pathEntityMergeID() framework.OperationFunc { return nil, err } - userErr, intErr := i.mergeEntity(ctx, txn, toEntity, fromEntityIDs, conflictingAliasIDsToKeep, force, false, false, true, false) + userErr, intErr, aliases := i.mergeEntity(ctx, txn, toEntity, fromEntityIDs, conflictingAliasIDsToKeep, force, false, false, true, false) if userErr != nil { - return logical.ErrorResponse(userErr.Error()), nil + // Not an error due to alias clash, return like normal + if len(aliases) == 0 { + return logical.ErrorResponse(userErr.Error()), nil + } + // Alias clash error, so include additional details + resp := &logical.Response{ + Data: map[string]interface{}{ + "error": userErr.Error(), + "data": aliases, + }, + } + + return resp, nil } if intErr != nil { return nil, intErr @@ -734,29 +746,40 @@ func (i *IdentityStore) handlePathEntityListCommon(ctx context.Context, req *log } func (i *IdentityStore) mergeEntityAsPartOfUpsert(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityID string, persist bool) (error, error) { - return i.mergeEntity(ctx, txn, toEntity, []string{fromEntityID}, []string{}, true, false, true, persist, true) + err1, err2, _ := i.mergeEntity(ctx, txn, toEntity, []string{fromEntityID}, []string{}, true, false, true, persist, true) + return err1, err2 } -func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityIDs, conflictingAliasIDsToKeep []string, force, grabLock, mergePolicies, persist, forceMergeAliases bool) (error, error) { +// A small type to return useful information to the UI after an entity clash +// Every alias involved in a clash will be returned. +type aliasClashInformation struct { + Alias string `json:"alias"` + Entity string `json:"entity"` + EntityId string `json:"entity_id"` + Mount string `json:"mount"` + MountPath string `json:"mount_path"` +} + +func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityIDs, conflictingAliasIDsToKeep []string, force, grabLock, mergePolicies, persist, forceMergeAliases bool) (error, error, []aliasClashInformation) { if grabLock { i.lock.Lock() defer i.lock.Unlock() } if toEntity == nil { - return errors.New("entity id to merge to is invalid"), nil + return errors.New("entity id to merge to is invalid"), nil, nil } ns, err := namespace.FromContext(ctx) if err != nil { - return nil, err + return nil, err, nil } if toEntity.NamespaceID != ns.ID { - return errors.New("entity id to merge into does not belong to the request's namespace"), nil + return errors.New("entity id to merge into does not belong to the request's namespace"), nil, nil } if len(fromEntityIDs) > 1 && len(conflictingAliasIDsToKeep) > 1 { - return errors.New("aliases conflicts cannot be resolved with multiple from entity ids - merge one entity at a time"), nil + return errors.New("aliases conflicts cannot be resolved with multiple from entity ids - merge one entity at a time"), nil, nil } sanitizedFromEntityIDs := strutil.RemoveDuplicates(fromEntityIDs, false) @@ -764,27 +787,35 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // A map to check if there are any clashes between mount accessors for any of the sanitizedFromEntityIDs fromEntityAccessors := make(map[string]string) + // A list detailing all aliases where a clash has occurred, so that the error + // can be understood by the UI + aliasesInvolvedInClashes := make([]aliasClashInformation, 0) + + i.UpdateEntityWithMountInformation(toEntity) + // An error detailing if any alias clashes happen (shared mount accessor) var aliasClashError error for _, fromEntityID := range sanitizedFromEntityIDs { if fromEntityID == toEntity.ID { - return errors.New("to_entity_id should not be present in from_entity_ids"), nil + return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil } fromEntity, err := i.MemDBEntityByID(fromEntityID, false) if err != nil { - return nil, err + return nil, err, nil } if fromEntity == nil { - return errors.New("entity id to merge from is invalid"), nil + return errors.New("entity id to merge from is invalid"), nil, nil } if fromEntity.NamespaceID != toEntity.NamespaceID { - return errors.New("entity id to merge from does not belong to this namespace"), nil + return errors.New("entity id to merge from does not belong to this namespace"), nil, nil } + i.UpdateEntityWithMountInformation(fromEntity) + // If we're not resolving a conflict, we check to see if // any aliases conflict between the toEntity and this fromEntity: if !forceMergeAliases && len(conflictingAliasIDsToKeep) == 0 { @@ -793,7 +824,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // First, check to see if this alias clashes with an alias from any of the other fromEntities: id, mountAccessorInAnotherFromEntity := fromEntityAccessors[fromAlias.MountAccessor] if mountAccessorInAnotherFromEntity && (id != fromEntityID) { - return fmt.Errorf("mount accessor %s found in multiple fromEntities, merge should be done with one fromEntity at a time", fromAlias.MountAccessor), nil + return fmt.Errorf("mount accessor %s found in multiple fromEntities, merge should be done with one fromEntity at a time", fromAlias.MountAccessor), nil, nil } fromEntityAccessors[fromAlias.MountAccessor] = fromEntityID @@ -806,6 +837,22 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit aliasClashError = multierror.Append(aliasClashError, fmt.Errorf("mountAccessor: %s, toEntity ID: %s, fromEntity ID: %s, conflicting toEntity alias ID: %s, conflicting fromEntity alias ID: %s", toAlias.MountAccessor, toEntity.ID, fromEntityID, toAlias.ID, fromAlias.ID)) + + // Also add both to our summary of all clashes: + aliasesInvolvedInClashes = append(aliasesInvolvedInClashes, aliasClashInformation{ + Entity: toEntity.Name, + EntityId: toEntity.ID, + Alias: toAlias.Name, + Mount: toAlias.MountType, + MountPath: toAlias.MountPath, + }) + aliasesInvolvedInClashes = append(aliasesInvolvedInClashes, aliasClashInformation{ + Entity: fromEntity.Name, + EntityId: fromEntityID, + Alias: fromAlias.Name, + Mount: fromAlias.MountType, + MountPath: fromAlias.MountPath, + }) } } } @@ -814,7 +861,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit for configID, configSecret := range fromEntity.MFASecrets { _, ok := toEntity.MFASecrets[configID] if ok && !force { - return nil, fmt.Errorf("conflicting MFA config ID %q in entity ID %q", configID, fromEntity.ID) + return nil, fmt.Errorf("conflicting MFA config ID %q in entity ID %q", configID, fromEntity.ID), nil } else { if toEntity.MFASecrets == nil { toEntity.MFASecrets = make(map[string]*mfa.Secret) @@ -826,7 +873,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // Check alias clashes after validating every fromEntity, so that we have a full list of errors if aliasClashError != nil { - return aliasClashError, nil + return aliasClashError, nil, aliasesInvolvedInClashes } isPerfSecondaryOrStandby := i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || @@ -849,20 +896,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit for _, fromEntityID := range sanitizedFromEntityIDs { if fromEntityID == toEntity.ID { - return errors.New("to_entity_id should not be present in from_entity_ids"), nil + return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil } fromEntity, err := i.MemDBEntityByID(fromEntityID, true) if err != nil { - return nil, err + return nil, err, nil } if fromEntity == nil { - return errors.New("entity id to merge from is invalid"), nil + return errors.New("entity id to merge from is invalid"), nil, nil } if fromEntity.NamespaceID != toEntity.NamespaceID { - return errors.New("entity id to merge from does not belong to this namespace"), nil + return errors.New("entity id to merge from does not belong to this namespace"), nil, nil } for _, fromAlias := range fromEntity.Aliases { @@ -877,13 +924,13 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) if err != nil { - return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) + return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err), nil } } else if strutil.StrListContains(conflictingAliasIDsToKeep, toAliasId) { i.logger.Info("Deleting from_entity alias during entity merge", "from_entity", fromEntityID, "deleted_alias", fromAlias.ID) err := i.MemDBDeleteAliasByIDInTxn(txn, fromAlias.ID, false) if err != nil { - return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) + return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err), nil } // Continue to next alias, as there's no alias to merge left in the from_entity @@ -892,10 +939,10 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) if err != nil { - return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) + return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err), nil } } else { - return fmt.Errorf("conflicting mount accessors in following alias IDs and neither were present in conflicting_alias_ids_to_keep: %s, %s", fromAlias.ID, toAliasId), nil + return fmt.Errorf("conflicting mount accessors in following alias IDs and neither were present in conflicting_alias_ids_to_keep: %s, %s", fromAlias.ID, toAliasId), nil, nil } } } @@ -907,7 +954,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false) if err != nil { - return nil, fmt.Errorf("failed to update alias during merge: %w", err) + return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil } // Add the alias to the desired entity @@ -932,13 +979,13 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // internal and external groups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, fromEntity.ID, true, false) if err != nil { - return nil, err + return nil, err, nil } for _, group := range groups { group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, fromEntity.ID) err = i.UpsertGroupInTxn(ctx, txn, group, persist && !isPerfSecondaryOrStandby) if err != nil { - return nil, err + return nil, err, nil } fromEntityGroups = append(fromEntityGroups, group) @@ -947,14 +994,14 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // Delete the entity which we are merging from in MemDB using the same transaction err = i.MemDBDeleteEntityByIDInTxn(txn, fromEntity.ID) if err != nil { - return nil, err + return nil, err, nil } if persist && !isPerfSecondaryOrStandby { // Delete the entity which we are merging from in storage err = i.entityPacker.DeleteItem(ctx, fromEntity.ID) if err != nil { - return nil, err + return nil, err, nil } } } @@ -962,14 +1009,14 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // Update MemDB with changes to the entity we are merging to err = i.MemDBUpsertEntityInTxn(txn, toEntity) if err != nil { - return nil, err + return nil, err, nil } for _, group := range fromEntityGroups { group.MemberEntityIDs = append(group.MemberEntityIDs, toEntity.ID) err = i.UpsertGroupInTxn(ctx, txn, group, persist && !isPerfSecondaryOrStandby) if err != nil { - return nil, err + return nil, err, nil } } @@ -977,7 +1024,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit // Persist the entity which we are merging to toEntityAsAny, err := ptypes.MarshalAny(toEntity) if err != nil { - return nil, err + return nil, err, nil } item := &storagepacker.Item{ ID: toEntity.ID, @@ -986,11 +1033,11 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit err = i.entityPacker.PutItem(ctx, item) if err != nil { - return nil, err + return nil, err, nil } } - return nil, nil + return nil, nil, nil } var entityHelp = map[string][2]string{ diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 2eeef336f..cdb0bf0de 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -1092,6 +1092,18 @@ func (i *IdentityStore) MemDBEntityByIDInTxn(txn *memdb.Txn, entityID string, cl return entity, nil } +func (i *IdentityStore) UpdateEntityWithMountInformation(entity *identity.Entity) { + if entity != nil { + for _, alias := range entity.Aliases { + mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor) + if mountValidationResp != nil { + alias.MountType = mountValidationResp.MountType + alias.MountPath = mountValidationResp.MountPath + } + } + } +} + func (i *IdentityStore) MemDBEntityByID(entityID string, clone bool) (*identity.Entity, error) { if entityID == "" { return nil, fmt.Errorf("missing entity id")