diff --git a/changelog/16539.txt b/changelog/16539.txt new file mode 100644 index 000000000..9927329b5 --- /dev/null +++ b/changelog/16539.txt @@ -0,0 +1,3 @@ +```release-note:change +core/entities: Fixed stranding of aliases upon entity merge, and require explicit selection of which aliases should be kept when some must be deleted +``` diff --git a/vault/external_tests/identity/aliases_test.go b/vault/external_tests/identity/aliases_test.go index 059524275..9928ccdb9 100644 --- a/vault/external_tests/identity/aliases_test.go +++ b/vault/external_tests/identity/aliases_test.go @@ -1,9 +1,13 @@ package identity import ( + "context" + "fmt" + "strings" "testing" "github.com/hashicorp/vault/api" + auth "github.com/hashicorp/vault/api/auth/userpass" "github.com/hashicorp/vault/builtin/credential/github" "github.com/hashicorp/vault/builtin/credential/userpass" vaulthttp "github.com/hashicorp/vault/http" @@ -250,3 +254,528 @@ func TestIdentityStore_RenameAlias_CannotMergeEntity(t *testing.T) { t.Fatal("expected rename over existing entity to fail") } } + +func TestIdentityStore_MergeEntities_FailsDueToClash(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + mounts, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + if mountAccessor == "" { + t.Fatal("did not find userpass accessor") + } + + _, entityIdBob, aliasIdBob := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + + // Create userpass login for alice + _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + _, entityIdAlice, aliasIdAlice := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + + // Perform entity merge + mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ + "to_entity_id": entityIdBob, + "from_entity_ids": entityIdAlice, + }) + if err == nil { + t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp) + } + if !strings.Contains(err.Error(), "toEntity and at least one fromEntity have aliases with the same mount accessor") { + t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %v", err) + } + if !strings.Contains(err.Error(), entityIdAlice) { + t.Fatalf("Did not identify alice's entity (%s) as conflicting. Error: %v", entityIdAlice, err) + } + if !strings.Contains(err.Error(), entityIdBob) { + t.Fatalf("Did not identify bob's entity (%s) as conflicting. Error: %v", entityIdBob, err) + } + if !strings.Contains(err.Error(), aliasIdAlice) { + t.Fatalf("Did not identify alice's alias (%s) as conflicting. Error: %v", aliasIdAlice, err) + } + if !strings.Contains(err.Error(), aliasIdBob) { + t.Fatalf("Did not identify bob's alias (%s) as conflicting. Error: %v", aliasIdBob, err) + } + if !strings.Contains(err.Error(), mountAccessor) { + t.Fatalf("Did not identify mount accessor %s as being reason for conflict. Error: %v", mountAccessor, err) + } +} + +func TestIdentityStore_MergeEntities_FailsDueToClashInFromEntities(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + "github": github.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + err = client.Sys().EnableAuthWithOptions("github", &api.EnableAuthOptions{ + Type: "github", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + mounts, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + if mountAccessor == "" { + t.Fatal("did not find userpass accessor") + } + + var mountAccessorGitHub string + for k, v := range mounts { + if k == "github/" { + mountAccessorGitHub = v.Accessor + break + } + } + if mountAccessorGitHub == "" { + t.Fatal("did not find github accessor") + } + + _, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + _, entityIdAlice, _ := createEntityAndAlias(client, mountAccessorGitHub, "alice-smith", "alice", t) + _, entityIdClara, _ := createEntityAndAlias(client, mountAccessorGitHub, "clara-smith", "clara", t) + + // Perform entity merge + mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ + "to_entity_id": entityIdBob, + "from_entity_ids": []string{entityIdAlice, entityIdClara}, + }) + if err == nil { + t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp) + } + if !strings.Contains(err.Error(), fmt.Sprintf("mount accessor %s found in multiple fromEntities, merge should be done with one fromEntity at a time", mountAccessorGitHub)) { + t.Fatalf("Error was not due to conflicting alias mount accessors in fromEntities. Error: %v", err) + } +} + +func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + "github": github.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + err = client.Sys().EnableAuthWithOptions("github", &api.EnableAuthOptions{ + Type: "github", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob-github", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + mounts, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + if mountAccessor == "" { + t.Fatal("did not find userpass accessor") + } + + var mountAccessorGitHub string + for k, v := range mounts { + if k == "github/" { + mountAccessorGitHub = v.Accessor + break + } + } + if mountAccessorGitHub == "" { + t.Fatal("did not find github accessor") + } + + _, entityIdBob, aliasIdBob := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + + aliasResp, err := client.Logical().Write("identity/entity-alias", map[string]interface{}{ + "name": "bob-github", + "canonical_id": entityIdBob, + "mount_accessor": mountAccessorGitHub, + }) + if err != nil { + t.Fatalf("err:%v resp:%#v", err, aliasResp) + } + + aliasIdBobGitHub := aliasResp.Data["id"].(string) + if aliasIdBobGitHub == "" { + t.Fatal("Alias ID not present in response") + } + + // Create userpass login for alice + _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ + "password": "training", + }) + if err != nil { + t.Fatal(err) + } + + _, entityIdAlice, aliasIdAlice := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + _, entityIdClara, aliasIdClara := createEntityAndAlias(client, mountAccessorGitHub, "clara-smith", "clara", t) + + // Perform entity merge + mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ + "to_entity_id": entityIdBob, + "from_entity_ids": []string{entityIdAlice, entityIdClara}, + }) + if err == nil { + t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp) + } + if !strings.Contains(err.Error(), "toEntity and at least one fromEntity have aliases with the same mount accessor") { + t.Fatalf("Error was not due to conflicting alias mount accessors. Error: %v", err) + } + if !strings.Contains(err.Error(), entityIdAlice) { + t.Fatalf("Did not identify alice's entity (%s) as conflicting. Error: %v", entityIdAlice, err) + } + if !strings.Contains(err.Error(), entityIdBob) { + t.Fatalf("Did not identify bob's entity (%s) as conflicting. Error: %v", entityIdBob, err) + } + if !strings.Contains(err.Error(), entityIdClara) { + t.Fatalf("Did not identify clara's alias (%s) as conflicting. Error: %v", entityIdClara, err) + } + if !strings.Contains(err.Error(), aliasIdAlice) { + t.Fatalf("Did not identify alice's alias (%s) as conflicting. Error: %v", aliasIdAlice, err) + } + if !strings.Contains(err.Error(), aliasIdBob) { + t.Fatalf("Did not identify bob's alias (%s) as conflicting. Error: %v", aliasIdBob, err) + } + if !strings.Contains(err.Error(), aliasIdClara) { + t.Fatalf("Did not identify bob's alias (%s) as conflicting. Error: %v", aliasIdClara, err) + } + if !strings.Contains(err.Error(), mountAccessor) { + t.Fatalf("Did not identify mount accessor %s as being reason for conflict. Error: %v", mountAccessor, err) + } + if !strings.Contains(err.Error(), mountAccessorGitHub) { + t.Fatalf("Did not identify mount accessor %s as being reason for conflict. Error: %v", mountAccessorGitHub, err) + } +} + +func TestIdentityStore_MergeEntities_SameMountAccessor_ThenUseAlias(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + _, err = client.Logical().Write("auth/userpass/login/bob", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + + mounts, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + if mountAccessor == "" { + t.Fatal("did not find userpass accessor") + } + + _, entityIdBob, aliasIdBob := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + + // Create userpass login for alice + _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + _, err = client.Logical().Write("auth/userpass/login/alice", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + + _, entityIdAlice, _ := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + + // Try and login with alias 2 (alice) pre-merge + userpassAuth, err := auth.NewUserpassAuth("alice", &auth.Password{FromString: "testpassword"}) + if err != nil { + t.Fatal(err) + } + loginResp, err := client.Logical().Write("auth/userpass/login/alice", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatalf("err:%v resp:%#v", err, loginResp) + } + if loginResp.Auth == nil { + t.Fatalf("Request auth is nil, something has gone wrong - resp:%#v", loginResp) + } + loginEntityId := loginResp.Auth.EntityID + if loginEntityId != entityIdAlice { + t.Fatalf("Login entity ID is not Alice. loginEntityId:%s aliceEntityId:%s", loginEntityId, entityIdAlice) + } + + // Perform entity merge + mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ + "to_entity_id": entityIdBob, + "from_entity_ids": entityIdAlice, + "conflicting_alias_ids_to_keep": aliasIdBob, + }) + if err != nil { + t.Fatalf("err:%v resp:%#v", err, mergeResp) + } + + // Delete entity id 1 (bob) + deleteResp, err := client.Logical().Delete(fmt.Sprintf("identity/entity/id/%s", entityIdBob)) + if err != nil { + t.Fatalf("err:%v resp:%#v", err, deleteResp) + } + + // Try and login with alias 2 (alice) post-merge + // Notably, this login method sets the client token, which is why we didn't use it above + loginResp, err = client.Auth().Login(context.Background(), userpassAuth) + if err != nil { + t.Fatalf("err:%v resp:%#v", err, loginResp) + } + if loginResp.Auth == nil { + t.Fatalf("Request auth is nil, something has gone wrong - resp:%#v", loginResp) + } + if loginEntityId != entityIdAlice { + t.Fatalf("Login entity ID is not Alice. loginEntityId:%s aliceEntityId:%s", loginEntityId, entityIdAlice) + } +} + +func TestIdentityStore_MergeEntities_FailsDueToMultipleClashMergesAttempted(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + "github": github.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + err = client.Sys().EnableAuthWithOptions("github", &api.EnableAuthOptions{ + Type: "github", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/bob-github", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + + mounts, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var mountAccessor string + for k, v := range mounts { + if k == "userpass/" { + mountAccessor = v.Accessor + break + } + } + if mountAccessor == "" { + t.Fatal("did not find userpass accessor") + } + + var mountAccessorGitHub string + for k, v := range mounts { + if k == "github/" { + mountAccessorGitHub = v.Accessor + break + } + } + if mountAccessorGitHub == "" { + t.Fatal("did not find github accessor") + } + + _, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + aliasResp, err := client.Logical().Write("identity/entity-alias", map[string]interface{}{ + "name": "bob-github", + "canonical_id": entityIdBob, + "mount_accessor": mountAccessorGitHub, + }) + if err != nil { + t.Fatalf("err:%v resp:%#v", err, aliasResp) + } + + aliasIdBobGitHub := aliasResp.Data["id"].(string) + if aliasIdBobGitHub == "" { + t.Fatal("Alias ID not present in response") + } + + // Create userpass login for alice + _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + + _, entityIdAlice, aliasIdAlice := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + _, entityIdClara, aliasIdClara := createEntityAndAlias(client, mountAccessorGitHub, "clara-smith", "alice", t) + + // Perform entity merge + mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ + "to_entity_id": entityIdBob, + "from_entity_ids": []string{entityIdAlice, entityIdClara}, + "conflicting_alias_ids_to_keep": []string{aliasIdAlice, aliasIdClara}, + }) + if err == nil { + t.Fatalf("Expected error upon merge. Resp:%#v", mergeResp) + } + if !strings.Contains(err.Error(), "merge one entity at a time") { + t.Fatalf("did not error for the right reason. Error: %v", err) + } +} diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index d0871cc28..15d7bf009 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -18,7 +18,7 @@ import ( "github.com/hashicorp/vault/vault" ) -func createEntityAndAlias(client *api.Client, mountAccessor, entityName, aliasName string, t *testing.T) (*api.Client, string) { +func createEntityAndAlias(client *api.Client, mountAccessor, entityName, aliasName string, t *testing.T) (*api.Client, string, string) { _, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/users/%s", aliasName), map[string]interface{}{ "password": "testpassword", }) @@ -40,7 +40,7 @@ func createEntityAndAlias(client *api.Client, mountAccessor, entityName, aliasNa } entityID := resp.Data["id"].(string) - _, err = client.Logical().WriteWithContext(context.Background(), "identity/entity-alias", map[string]interface{}{ + aliasResp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity-alias", map[string]interface{}{ "name": aliasName, "canonical_id": entityID, "mount_accessor": mountAccessor, @@ -48,7 +48,12 @@ func createEntityAndAlias(client *api.Client, mountAccessor, entityName, aliasNa if err != nil { t.Fatalf("failed to create an entity alias:%v", err) } - return userClient, entityID + + aliasID := aliasResp.Data["id"].(string) + if aliasID == "" { + t.Fatal("Alias ID not present in response") + } + return userClient, entityID, aliasID } func registerEntityInTOTPEngine(client *api.Client, entityID, methodID string, t *testing.T) string { @@ -162,8 +167,8 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { } // Creating two users in the userpass auth mount - userClient1, entityID1 := createEntityAndAlias(client, mountAccessor, "entity1", "testuser1", t) - userClient2, entityID2 := createEntityAndAlias(client, mountAccessor, "entity2", "testuser2", t) + userClient1, entityID1, _ := createEntityAndAlias(client, mountAccessor, "entity1", "testuser1", t) + userClient2, entityID2, _ := createEntityAndAlias(client, mountAccessor, "entity2", "testuser2", t) // configure TOTP secret engine var methodID string diff --git a/vault/identity_store_entities.go b/vault/identity_store_entities.go index cc85d0179..11db95d8b 100644 --- a/vault/identity_store_entities.go +++ b/vault/identity_store_entities.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" + "github.com/hashicorp/go-multierror" + "github.com/golang/protobuf/ptypes" memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-secure-stdlib/strutil" @@ -126,12 +128,16 @@ func entityPaths(i *IdentityStore) []*framework.Path { Fields: map[string]*framework.FieldSchema{ "from_entity_ids": { Type: framework.TypeCommaStringSlice, - Description: "Entity IDs which needs to get merged", + Description: "Entity IDs which need to get merged", }, "to_entity_id": { Type: framework.TypeString, Description: "Entity ID into which all the other entities need to get merged", }, + "conflicting_alias_ids_to_keep": { + Type: framework.TypeCommaStringSlice, + Description: "Alias IDs to keep in case of conflicting aliases. Ignored if no conflicting aliases found", + }, "force": { Type: framework.TypeBool, Description: "Setting this will follow the 'mine' strategy for merging MFA secrets. If there are secrets of the same type both in entities that are merged from and in entity into which all others are getting merged, secrets in the destination will be unaltered. If not set, this API will throw an error containing all the conflicts.", @@ -150,17 +156,27 @@ func entityPaths(i *IdentityStore) []*framework.Path { // pathEntityMergeID merges two or more entities into a single entity func (i *IdentityStore) pathEntityMergeID() framework.OperationFunc { return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - toEntityID := d.Get("to_entity_id").(string) - if toEntityID == "" { + toEntityIDInterface, ok := d.GetOk("to_entity_id") + if !ok || toEntityIDInterface == "" { return logical.ErrorResponse("missing entity id to merge to"), nil } + toEntityID := toEntityIDInterface.(string) - fromEntityIDs := d.Get("from_entity_ids").([]string) - if len(fromEntityIDs) == 0 { + fromEntityIDsInterface, ok := d.GetOk("from_entity_ids") + if !ok || len(fromEntityIDsInterface.([]string)) == 0 { return logical.ErrorResponse("missing entity ids to merge from"), nil } + fromEntityIDs := fromEntityIDsInterface.([]string) - force := d.Get("force").(bool) + var conflictingAliasIDsToKeep []string + if conflictingAliasIDsToKeepInterface, ok := d.GetOk("conflicting_alias_ids_to_keep"); ok { + conflictingAliasIDsToKeep = conflictingAliasIDsToKeepInterface.([]string) + } + + var force bool + if forceInterface, ok := d.GetOk("force"); ok { + force = forceInterface.(bool) + } // Create a MemDB transaction to merge entities i.lock.Lock() @@ -174,7 +190,7 @@ func (i *IdentityStore) pathEntityMergeID() framework.OperationFunc { return nil, err } - userErr, intErr := i.mergeEntity(ctx, txn, toEntity, fromEntityIDs, force, false, false, true) + userErr, intErr := i.mergeEntity(ctx, txn, toEntity, fromEntityIDs, conflictingAliasIDsToKeep, force, false, false, true, false) if userErr != nil { return logical.ErrorResponse(userErr.Error()), nil } @@ -717,7 +733,11 @@ func (i *IdentityStore) handlePathEntityListCommon(ctx context.Context, req *log return logical.ListResponseWithInfo(keys, entityInfo), nil } -func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityIDs []string, force, grabLock, mergePolicies, persist bool) (error, error) { +func (i *IdentityStore) mergeEntityAsPartOfUpsert(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityID string, persist bool) (error, error) { + return i.mergeEntity(ctx, txn, toEntity, []string{fromEntityID}, []string{}, true, false, true, persist, true) +} + +func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntity *identity.Entity, fromEntityIDs, conflictingAliasIDsToKeep []string, force, grabLock, mergePolicies, persist, forceMergeAliases bool) (error, error) { if grabLock { i.lock.Lock() defer i.lock.Unlock() @@ -735,9 +755,18 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit return errors.New("entity id to merge into does not belong to the request's namespace"), nil } + if len(fromEntityIDs) > 1 && len(conflictingAliasIDsToKeep) > 1 { + return errors.New("aliases conflicts cannot be resolved with multiple from entity ids - merge one entity at a time"), nil + } + sanitizedFromEntityIDs := strutil.RemoveDuplicates(fromEntityIDs, false) - // Merge the MFA secrets + // A map to check if there are any clashes between mount accessors for any of the sanitizedFromEntityIDs + fromEntityAccessors := make(map[string]string) + + // An error detailing if any alias clashes happen (shared mount accessor) + var aliasClashError error + for _, fromEntityID := range sanitizedFromEntityIDs { if fromEntityID == toEntity.ID { return errors.New("to_entity_id should not be present in from_entity_ids"), nil @@ -756,6 +785,32 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit return errors.New("entity id to merge from does not belong to this namespace"), nil } + // If we're not resolving a conflict, we check to see if + // any aliases conflict between the toEntity and this fromEntity: + if !forceMergeAliases && len(conflictingAliasIDsToKeep) == 0 { + for _, toAlias := range toEntity.Aliases { + for _, fromAlias := range fromEntity.Aliases { + // First, check to see if this alias clashes with an alias from any of the other fromEntities: + id, mountAccessorInAnotherFromEntity := fromEntityAccessors[fromAlias.MountAccessor] + if mountAccessorInAnotherFromEntity && (id != fromEntityID) { + return fmt.Errorf("mount accessor %s found in multiple fromEntities, merge should be done with one fromEntity at a time", fromAlias.MountAccessor), nil + } + + fromEntityAccessors[fromAlias.MountAccessor] = fromEntityID + + // If it doesn't, check if it clashes with the toEntities + if toAlias.MountAccessor == fromAlias.MountAccessor { + if aliasClashError == nil { + aliasClashError = multierror.Append(aliasClashError, fmt.Errorf("toEntity and at least one fromEntity have aliases with the same mount accessor, repeat the merge request specifying exactly one fromEntity, clashes: ")) + } + aliasClashError = multierror.Append(aliasClashError, + fmt.Errorf("mountAccessor: %s, toEntity ID: %s, fromEntity ID: %s, conflicting toEntity alias ID: %s, conflicting fromEntity alias ID: %s", + toAlias.MountAccessor, toEntity.ID, fromEntityID, toAlias.ID, fromAlias.ID)) + } + } + } + } + for configID, configSecret := range fromEntity.MFASecrets { _, ok := toEntity.MFASecrets[configID] if ok && !force { @@ -769,15 +824,26 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit } } + // Check alias clashes after validating every fromEntity, so that we have a full list of errors + if aliasClashError != nil { + return aliasClashError, nil + } + isPerfSecondaryOrStandby := i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby var fromEntityGroups []*identity.Group - toEntityAccessors := make(map[string]struct{}) + toEntityAccessors := make(map[string][]string) for _, alias := range toEntity.Aliases { - if _, ok := toEntityAccessors[alias.MountAccessor]; !ok { - toEntityAccessors[alias.MountAccessor] = struct{}{} + if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok { + // While it is not supported to have multiple aliases with the same mount accessor in one entity + // we do not strictly enforce the invariant. Thus, we account for multiple just to be safe + if accessors == nil { + toEntityAccessors[alias.MountAccessor] = []string{alias.ID} + } else { + toEntityAccessors[alias.MountAccessor] = append(accessors, alias.ID) + } } } @@ -799,23 +865,53 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit return errors.New("entity id to merge from does not belong to this namespace"), nil } - for _, alias := range fromEntity.Aliases { + for _, fromAlias := range fromEntity.Aliases { + // If true, we need to handle conflicts (conflict = both aliases share the same mount accessor) + if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok { + for _, toAliasId := range toAliasIds { + // When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision + // for the user, and keep the to_entity alias, merging the from_entity + // This case's code is the same as when the user selects to keep the from_entity alias + // but is kept separate for clarity + if forceMergeAliases { + i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) + err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) + if err != nil { + return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) + } + } else if strutil.StrListContains(conflictingAliasIDsToKeep, toAliasId) { + i.logger.Info("Deleting from_entity alias during entity merge", "from_entity", fromEntityID, "deleted_alias", fromAlias.ID) + err := i.MemDBDeleteAliasByIDInTxn(txn, fromAlias.ID, false) + if err != nil { + return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) + } + + // Continue to next alias, as there's no alias to merge left in the from_entity + continue + } else if strutil.StrListContains(conflictingAliasIDsToKeep, fromAlias.ID) { + i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) + err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) + if err != nil { + return nil, fmt.Errorf("failed to delete orphaned alias during merge: %w", err) + } + } else { + return fmt.Errorf("conflicting mount accessors in following alias IDs and neither were present in conflicting_alias_ids_to_keep: %s, %s", fromAlias.ID, toAliasId), nil + } + } + } + // Set the desired canonical ID - alias.CanonicalID = toEntity.ID + fromAlias.CanonicalID = toEntity.ID - alias.MergedFromCanonicalIDs = append(alias.MergedFromCanonicalIDs, fromEntity.ID) + fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID) - err = i.MemDBUpsertAliasInTxn(txn, alias, false) + err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false) if err != nil { return nil, fmt.Errorf("failed to update alias during merge: %w", err) } - if _, ok := toEntityAccessors[alias.MountAccessor]; ok { - i.logger.Warn("skipping from_entity alias during entity merge as to_entity has an alias with its accessor", "from_entity", fromEntityID, "skipped_alias", alias.ID) - continue - } // Add the alias to the desired entity - toEntity.Aliases = append(toEntity.Aliases, alias) + toEntity.Aliases = append(toEntity.Aliases, fromAlias) } // If told to, merge policies diff --git a/vault/identity_store_entities_test.go b/vault/identity_store_entities_test.go index 2f7139cea..52462832c 100644 --- a/vault/identity_store_entities_test.go +++ b/vault/identity_store_entities_test.go @@ -1008,12 +1008,6 @@ func TestIdentityStore_MergeEntitiesByID(t *testing.T) { aliasRegisterData2 := map[string]interface{}{ "name": "testaliasname2", - "mount_accessor": githubAccessor, - "metadata": []string{"organization=hashicorp", "team=vault"}, - } - - aliasRegisterData3 := map[string]interface{}{ - "name": "testaliasname3", "mount_accessor": upAccessor, "metadata": []string{"organization=hashicorp", "team=vault"}, } @@ -1079,24 +1073,10 @@ func TestIdentityStore_MergeEntitiesByID(t *testing.T) { } entityID2 := resp.Data["id"].(string) - // Set entity ID in alias registration data and register alias + aliasRegisterData2["entity_id"] = entityID2 - aliasReq = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "alias", - Data: aliasRegisterData2, - } - - // Register the alias - resp, err = is.HandleRequest(ctx, aliasReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - - aliasRegisterData3["entity_id"] = entityID2 - - aliasReq.Data = aliasRegisterData3 + aliasReq.Data = aliasRegisterData2 // Register the alias resp, err = is.HandleRequest(ctx, aliasReq) @@ -1111,8 +1091,8 @@ func TestIdentityStore_MergeEntitiesByID(t *testing.T) { t.Fatalf("failed to create entity: %v", err) } - if len(entity2.Aliases) != 2 { - t.Fatalf("bad: number of aliases in entity; expected: 2, actual: %d", len(entity2.Aliases)) + if len(entity2.Aliases) != 1 { + t.Fatalf("bad: number of aliases in entity; expected: 1, actual: %d", len(entity2.Aliases)) } entity2GroupReq := &logical.Request{ diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index a5e335f94..550035d5e 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -537,7 +537,6 @@ func TestIdentityStore_MergeConflictingAliases(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - c, _, _ := TestCoreUnsealed(t) meGH := &MountEntry{ diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index ae08bd387..2eeef336f 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -595,7 +595,7 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e default: i.logger.Warn("alias is already tied to a different entity; these entities are being merged", "alias_id", alias.ID, "other_entity_id", aliasByFactors.CanonicalID, "entity_aliases", entity.Aliases, "alias_by_factors", aliasByFactors) - respErr, intErr := i.mergeEntity(ctx, txn, entity, []string{aliasByFactors.CanonicalID}, true, false, true, persist) + respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persist) switch { case respErr != nil: return respErr @@ -604,7 +604,7 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e } // The entity and aliases will be loaded into memdb and persisted - // as a result of the merge so we are done here + // as a result of the merge, so we are done here return nil }