package vault import ( "context" "errors" "fmt" "strings" "sync" "time" metrics "github.com/armon/go-metrics" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/errwrap" memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-secure-stdlib/strutil" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/identity/mfa" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/storagepacker" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" ) var ( errDuplicateIdentityName = errors.New("duplicate identity name") errCycleDetectedPrefix = "cyclic relationship detected for member group ID" tmpSuffix = ".tmp" ) func (c *Core) SetLoadCaseSensitiveIdentityStore(caseSensitive bool) { c.loadCaseSensitiveIdentityStore = caseSensitive } func (c *Core) loadIdentityStoreArtifacts(ctx context.Context) error { if c.identityStore == nil { c.logger.Warn("identity store is not setup, skipping loading") return nil } loadFunc := func(context.Context) error { if err := c.identityStore.loadEntities(ctx); err != nil { return err } if err := c.identityStore.loadGroups(ctx); err != nil { return err } if err := c.identityStore.loadOIDCClients(ctx); err != nil { return err } if err := c.identityStore.loadCachedEntitiesOfLocalAliases(ctx); err != nil { return err } return nil } if !c.loadCaseSensitiveIdentityStore { // Load everything when memdb is set to operate on lower cased names err := loadFunc(ctx) switch { case err == nil: // If it succeeds, all is well return nil case !errwrap.Contains(err, errDuplicateIdentityName.Error()): return err } } c.identityStore.logger.Warn("enabling case sensitive identity names") // Set identity store to operate on case sensitive identity names c.identityStore.disableLowerCasedNames = true // Swap the memdb instance by the one which operates on case sensitive // names, hence obviating the need to unload anything that's already // loaded. if err := c.identityStore.resetDB(ctx); err != nil { return err } // Attempt to load identity artifacts once more after memdb is reset to // accept case sensitive names return loadFunc(ctx) } func (i *IdentityStore) sanitizeName(name string) string { if i.disableLowerCasedNames { return name } return strings.ToLower(name) } func (i *IdentityStore) loadGroups(ctx context.Context) error { i.logger.Debug("identity loading groups") existing, err := i.groupPacker.View().List(ctx, groupBucketsPrefix) if err != nil { return fmt.Errorf("failed to scan for groups: %w", err) } i.logger.Debug("groups collected", "num_existing", len(existing)) for _, key := range existing { bucket, err := i.groupPacker.GetBucket(ctx, groupBucketsPrefix+key) if err != nil { return err } if bucket == nil { continue } for _, item := range bucket.Items { group, err := i.parseGroupFromBucketItem(item) if err != nil { return err } if group == nil { continue } ns, err := i.namespacer.NamespaceByID(ctx, group.NamespaceID) if err != nil { return err } if ns == nil { // Remove dangling groups if !(i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby) { // Group's namespace doesn't exist anymore but the group // from the namespace still exists. i.logger.Warn("deleting group and its any existing aliases", "name", group.Name, "namespace_id", group.NamespaceID) err = i.groupPacker.DeleteItem(ctx, group.ID) if err != nil { return err } } continue } nsCtx := namespace.ContextWithNamespace(ctx, ns) // Ensure that there are no groups with duplicate names groupByName, err := i.MemDBGroupByName(nsCtx, group.Name, false) if err != nil { return err } if groupByName != nil { i.logger.Warn(errDuplicateIdentityName.Error(), "group_name", group.Name, "conflicting_group_name", groupByName.Name, "action", "merge the contents of duplicated groups into one and delete the other") if !i.disableLowerCasedNames { return errDuplicateIdentityName } } if i.logger.IsDebug() { i.logger.Debug("loading group", "name", group.Name, "id", group.ID) } txn := i.db.Txn(true) // Before pull#5786, entity memberships in groups were not getting // updated when respective entities were deleted. This is here to // check that the entity IDs in the group are indeed valid, and if // not remove them. persist := false for _, memberEntityID := range group.MemberEntityIDs { entity, err := i.MemDBEntityByID(memberEntityID, false) if err != nil { txn.Abort() return err } if entity == nil { persist = true group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, memberEntityID) } } err = i.UpsertGroupInTxn(ctx, txn, group, persist) if err != nil { txn.Abort() return fmt.Errorf("failed to update group in memdb: %w", err) } txn.Commit() } } if i.logger.IsInfo() { i.logger.Info("groups restored") } return nil } func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) error { // If we are performance secondary, load from temporary location those // entities that were created by the secondary via RPCs to the primary, and // also happen to have not yet been shipped to the secondary through // performance replication. if !i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) { return nil } i.logger.Debug("loading cached entities of local aliases") existing, err := i.localAliasPacker.View().List(ctx, localAliasesBucketsPrefix) if err != nil { return fmt.Errorf("failed to scan for cached entities of local alias: %w", err) } i.logger.Debug("cached entities of local alias entries", "num_buckets", len(existing)) // Make the channels used for the worker pool broker := make(chan string) quit := make(chan bool) // Buffer these channels to prevent deadlocks errs := make(chan error, len(existing)) result := make(chan *storagepacker.Bucket, len(existing)) // Use a wait group wg := &sync.WaitGroup{} // Create 64 workers to distribute work to for j := 0; j < consts.ExpirationRestoreWorkerCount; j++ { wg.Add(1) go func() { defer wg.Done() for { select { case key, ok := <-broker: // broker has been closed, we are done if !ok { return } bucket, err := i.localAliasPacker.GetBucket(ctx, localAliasesBucketsPrefix+key) if err != nil { errs <- err continue } // Write results out to the result channel result <- bucket // quit early case <-quit: return } } }() } // Distribute the collected keys to the workers in a go routine wg.Add(1) go func() { defer wg.Done() for j, key := range existing { if j%500 == 0 { i.logger.Debug("cached entities of local aliases loading", "progress", j) } select { case <-quit: return default: broker <- key } } // Close the broker, causing worker routines to exit close(broker) }() defer func() { // Let all go routines finish wg.Wait() i.logger.Info("cached entities of local aliases restored") }() // Restore each key by pulling from the result chan for j := 0; j < len(existing); j++ { select { case err := <-errs: // Close all go routines close(quit) return err case bucket := <-result: // If there is no entry, nothing to restore if bucket == nil { continue } for _, item := range bucket.Items { if !strings.HasSuffix(item.ID, tmpSuffix) { continue } entity, err := i.parseCachedEntity(item) if err != nil { return err } ns, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID) if err != nil { return err } nsCtx := namespace.ContextWithNamespace(ctx, ns) err = i.upsertEntity(nsCtx, entity, nil, false) if err != nil { return fmt.Errorf("failed to update entity in MemDB: %w", err) } } } } return nil } func (i *IdentityStore) loadEntities(ctx context.Context) error { // Accumulate existing entities i.logger.Debug("loading entities") existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix) if err != nil { return fmt.Errorf("failed to scan for entities: %w", err) } i.logger.Debug("entities collected", "num_existing", len(existing)) duplicatedAccessors := make(map[string]struct{}) // Make the channels used for the worker pool broker := make(chan string) quit := make(chan bool) // Buffer these channels to prevent deadlocks errs := make(chan error, len(existing)) result := make(chan *storagepacker.Bucket, len(existing)) // Use a wait group wg := &sync.WaitGroup{} // Create 64 workers to distribute work to for j := 0; j < consts.ExpirationRestoreWorkerCount; j++ { wg.Add(1) go func() { defer wg.Done() for { select { case key, ok := <-broker: // broker has been closed, we are done if !ok { return } bucket, err := i.entityPacker.GetBucket(ctx, storagepacker.StoragePackerBucketsPrefix+key) if err != nil { errs <- err continue } // Write results out to the result channel result <- bucket // quit early case <-quit: return } } }() } // Distribute the collected keys to the workers in a go routine wg.Add(1) go func() { defer wg.Done() for j, key := range existing { if j%500 == 0 { i.logger.Debug("entities loading", "progress", j) } select { case <-quit: return default: broker <- key } } // Close the broker, causing worker routines to exit close(broker) }() // Restore each key by pulling from the result chan LOOP: for j := 0; j < len(existing); j++ { select { case err = <-errs: // Close all go routines close(quit) break LOOP case bucket := <-result: // If there is no entry, nothing to restore if bucket == nil { continue } for _, item := range bucket.Items { entity, err := i.parseEntityFromBucketItem(ctx, item) if err != nil { return err } if entity == nil { continue } ns, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID) if err != nil { return err } if ns == nil { // Remove dangling entities if !(i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby) { // Entity's namespace doesn't exist anymore but the // entity from the namespace still exists. i.logger.Warn("deleting entity and its any existing aliases", "name", entity.Name, "namespace_id", entity.NamespaceID) err = i.entityPacker.DeleteItem(ctx, entity.ID) if err != nil { return err } } continue } nsCtx := namespace.ContextWithNamespace(ctx, ns) // Ensure that there are no entities with duplicate names entityByName, err := i.MemDBEntityByName(nsCtx, entity.Name, false) if err != nil { return nil } if entityByName != nil { i.logger.Warn(errDuplicateIdentityName.Error(), "entity_name", entity.Name, "conflicting_entity_name", entityByName.Name, "action", "merge the duplicate entities into one") if !i.disableLowerCasedNames { return errDuplicateIdentityName } } mountAccessors := getAccessorsOnDuplicateAliases(entity.Aliases) for _, accessor := range mountAccessors { if _, ok := duplicatedAccessors[accessor]; !ok { duplicatedAccessors[accessor] = struct{}{} } } localAliases, err := i.parseLocalAliases(entity.ID) if err != nil { return fmt.Errorf("failed to load local aliases from storage: %v", err) } if localAliases != nil { for _, alias := range localAliases.Aliases { entity.UpsertAlias(alias) } } // Only update MemDB and don't hit the storage again err = i.upsertEntity(nsCtx, entity, nil, false) if err != nil { return fmt.Errorf("failed to update entity in MemDB: %w", err) } } } } // Let all go routines finish wg.Wait() if err != nil { return err } // Flatten the map into a list of keys, in order to log them duplicatedAccessorsList := make([]string, len(duplicatedAccessors)) accessorCounter := 0 for accessor := range duplicatedAccessors { duplicatedAccessorsList[accessorCounter] = accessor accessorCounter++ } if len(duplicatedAccessorsList) > 0 { i.logger.Warn("One or more entities have multiple aliases on the same mount(s), remove duplicates to avoid ACL templating issues", "mount_accessors", duplicatedAccessorsList) } if i.logger.IsInfo() { i.logger.Info("entities restored") } return nil } // getAccessorsOnDuplicateAliases returns a list of accessors by checking aliases in // the passed in list which belong to the same accessor(s) func getAccessorsOnDuplicateAliases(aliases []*identity.Alias) []string { accessorCounts := make(map[string]int) var mountAccessors []string for _, alias := range aliases { accessorCounts[alias.MountAccessor] += 1 } for accessor, accessorCount := range accessorCounts { if accessorCount > 1 { mountAccessors = append(mountAccessors, accessor) } } return mountAccessors } // upsertEntityInTxn either creates or updates an existing entity. The // operations will be updated in both MemDB and storage. If 'persist' is set to // false, then storage will not be updated. When an alias is transferred from // one entity to another, both the source and destination entities should get // updated, in which case, callers should send in both entity and // previousEntity. func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error { defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now()) var err error if txn == nil { return errors.New("txn is nil") } if entity == nil { return errors.New("entity is nil") } if entity.NamespaceID == "" { entity.NamespaceID = namespace.RootNamespaceID } if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID { return errors.New("entity and previous entity are not in the same namespace") } aliasFactors := make([]string, len(entity.Aliases)) for index, alias := range entity.Aliases { // Verify that alias is not associated to a different one already aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false) if err != nil { return err } if alias.NamespaceID == "" { alias.NamespaceID = namespace.RootNamespaceID } switch { case aliasByFactors == nil: // Not found, no merging needed, just check namespace if alias.NamespaceID != entity.NamespaceID { return errors.New("alias and entity are not in the same namespace") } case aliasByFactors.CanonicalID == entity.ID: // Lookup found the same entity, so it's already attached to the // right place if aliasByFactors.NamespaceID != entity.NamespaceID { return errors.New("alias from factors and entity are not in the same namespace") } case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID: // previousEntity isn't upserted yet so may still contain the old // alias reference in memdb if it was just changed; validate // whether or not it's _actually_ still tied to the entity var found bool for _, prevEntAlias := range previousEntity.Aliases { if prevEntAlias.ID == alias.ID { found = true break } } // If we didn't find the alias still tied to previousEntity, we // shouldn't use the merging logic and should bail if !found { break } // Otherwise it's still tied to previousEntity and fall through // into merging. We don't need a namespace check here as existing // checks when creating the aliases should ensure that all line up. fallthrough default: i.logger.Warn("alias is already tied to a different entity; these entities are being merged", "alias_id", alias.ID, "other_entity_id", aliasByFactors.CanonicalID, "entity_aliases", entity.Aliases, "alias_by_factors", aliasByFactors) respErr, intErr := i.mergeEntity(ctx, txn, entity, []string{aliasByFactors.CanonicalID}, true, false, true, persist) switch { case respErr != nil: return respErr case intErr != nil: return intErr } // The entity and aliases will be loaded into memdb and persisted // as a result of the merge so we are done here return nil } if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) { i.logger.Warn(errDuplicateIdentityName.Error(), "alias_name", alias.Name, "mount_accessor", alias.MountAccessor, "local", alias.Local, "entity_name", entity.Name, "action", "delete one of the duplicate aliases") if !i.disableLowerCasedNames { return errDuplicateIdentityName } } // Insert or update alias in MemDB using the transaction created above err = i.MemDBUpsertAliasInTxn(txn, alias, false) if err != nil { return err } aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor } // If previous entity is set, update it in MemDB and persist it if previousEntity != nil { err = i.MemDBUpsertEntityInTxn(txn, previousEntity) if err != nil { return err } if persist { // Persist the previous entity object if err := i.persistEntity(ctx, previousEntity); err != nil { return err } } } // Insert or update entity in MemDB using the transaction created above err = i.MemDBUpsertEntityInTxn(txn, entity) if err != nil { return err } if persist { if err := i.persistEntity(ctx, entity); err != nil { return err } } return nil } func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) { if !lAlias.Local { return nil, fmt.Errorf("alias is not local") } mountValidationResp := i.router.ValidateMountByAccessor(lAlias.MountAccessor) if mountValidationResp == nil { return nil, fmt.Errorf("invalid mount accessor %q", lAlias.MountAccessor) } if !mountValidationResp.MountLocal { return nil, fmt.Errorf("mount accessor %q is not local", lAlias.MountAccessor) } alias, err := i.MemDBAliasByFactors(lAlias.MountAccessor, lAlias.Name, true, false) if err != nil { return nil, err } if alias == nil { alias = &identity.Alias{} } alias.CanonicalID = entity.ID alias.Name = lAlias.Name alias.MountAccessor = lAlias.MountAccessor alias.Metadata = lAlias.Metadata alias.MountPath = mountValidationResp.MountPath alias.MountType = mountValidationResp.MountType alias.Local = lAlias.Local alias.CustomMetadata = lAlias.CustomMetadata if err := i.sanitizeAlias(ctx, alias); err != nil { return nil, err } entity.UpsertAlias(alias) localAliases, err := i.parseLocalAliases(entity.ID) if err != nil { return nil, err } if localAliases == nil { localAliases = &identity.LocalAliases{} } updated := false for i, item := range localAliases.Aliases { if item.ID == alias.ID { localAliases.Aliases[i] = alias updated = true break } } if !updated { localAliases.Aliases = append(localAliases.Aliases, alias) } marshaledAliases, err := ptypes.MarshalAny(localAliases) if err != nil { return nil, err } if err := i.localAliasPacker.PutItem(ctx, &storagepacker.Item{ ID: entity.ID, Message: marshaledAliases, }); err != nil { return nil, err } if updateDb { txn := i.db.Txn(true) defer txn.Abort() if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil { return nil, err } if err := i.upsertEntityInTxn(ctx, txn, entity, nil, false); err != nil { return nil, err } txn.Commit() } return alias, nil } // cacheTemporaryEntity stores in secondary's storage, the entity returned by // the primary cluster via the CreateEntity RPC. This is so that the secondary // cluster knows and retains information about the existence of these entities // before the replication invalidation informs the secondary of the same. This // also happens to cover the case where the secondary's replication is lagging // behind the primary by hours and/or days which sometimes may happen. Even if // the nodes of the secondary are restarted in the interim, the cluster would // still be aware of the entities. This temporary cache will be cleared when the // invalidation hits the secondary nodes. func (i *IdentityStore) cacheTemporaryEntity(ctx context.Context, entity *identity.Entity) error { if i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) && i.localNode.HAState() != consts.PerfStandby { marshaledEntity, err := ptypes.MarshalAny(entity) if err != nil { return err } if err := i.localAliasPacker.PutItem(ctx, &storagepacker.Item{ ID: entity.ID + tmpSuffix, Message: marshaledEntity, }); err != nil { return err } } return nil } func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error { // If the entity that is passed into this function is resulting from a memdb // query without cloning, then modifying it will result in a direct DB edit, // bypassing the transaction. To avoid any surprises arising from this // effect, work on a replica of the entity struct. var err error entity, err = entity.Clone() if err != nil { return err } // Separate the local and non-local aliases. var localAliases []*identity.Alias var nonLocalAliases []*identity.Alias for _, alias := range entity.Aliases { switch alias.Local { case true: localAliases = append(localAliases, alias) default: nonLocalAliases = append(nonLocalAliases, alias) } } // Store the entity with non-local aliases. entity.Aliases = nonLocalAliases marshaledEntity, err := ptypes.MarshalAny(entity) if err != nil { return err } if err := i.entityPacker.PutItem(ctx, &storagepacker.Item{ ID: entity.ID, Message: marshaledEntity, }); err != nil { return err } if len(localAliases) == 0 { return nil } // Store the local aliases separately. aliases := &identity.LocalAliases{ Aliases: localAliases, } marshaledAliases, err := ptypes.MarshalAny(aliases) if err != nil { return err } if err := i.localAliasPacker.PutItem(ctx, &storagepacker.Item{ ID: entity.ID, Message: marshaledAliases, }); err != nil { return err } return nil } // upsertEntity either creates or updates an existing entity. The operations // will be updated in both MemDB and storage. If 'persist' is set to false, // then storage will not be updated. When an alias is transferred from one // entity to another, both the source and destination entities should get // updated, in which case, callers should send in both entity and // previousEntity. func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error { defer metrics.MeasureSince([]string{"identity", "upsert_entity"}, time.Now()) // Create a MemDB transaction to update both alias and entity txn := i.db.Txn(true) defer txn.Abort() err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist) if err != nil { return err } txn.Commit() return nil } func (i *IdentityStore) MemDBUpsertAliasInTxn(txn *memdb.Txn, alias *identity.Alias, groupAlias bool) error { if txn == nil { return fmt.Errorf("nil txn") } if alias == nil { return fmt.Errorf("alias is nil") } if alias.NamespaceID == "" { alias.NamespaceID = namespace.RootNamespaceID } tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } aliasRaw, err := txn.First(tableName, "id", alias.ID) if err != nil { return fmt.Errorf("failed to lookup alias from memdb using alias ID: %w", err) } if aliasRaw != nil { err = txn.Delete(tableName, aliasRaw) if err != nil { return fmt.Errorf("failed to delete alias from memdb: %w", err) } } if err := txn.Insert(tableName, alias); err != nil { return fmt.Errorf("failed to update alias into memdb: %w", err) } return nil } func (i *IdentityStore) MemDBAliasByIDInTxn(txn *memdb.Txn, aliasID string, clone bool, groupAlias bool) (*identity.Alias, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID") } if txn == nil { return nil, fmt.Errorf("txn is nil") } tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } aliasRaw, err := txn.First(tableName, "id", aliasID) if err != nil { return nil, fmt.Errorf("failed to fetch alias from memdb using alias ID: %w", err) } if aliasRaw == nil { return nil, nil } alias, ok := aliasRaw.(*identity.Alias) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched alias") } if clone { return alias.Clone() } return alias, nil } func (i *IdentityStore) MemDBAliasByID(aliasID string, clone bool, groupAlias bool) (*identity.Alias, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID") } txn := i.db.Txn(false) return i.MemDBAliasByIDInTxn(txn, aliasID, clone, groupAlias) } func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) { if aliasName == "" { return nil, fmt.Errorf("missing alias name") } if mountAccessor == "" { return nil, fmt.Errorf("missing mount accessor") } txn := i.db.Txn(false) return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias) } func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) { if txn == nil { return nil, fmt.Errorf("nil txn") } if aliasName == "" { return nil, fmt.Errorf("missing alias name") } if mountAccessor == "" { return nil, fmt.Errorf("missing mount accessor") } tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName) if err != nil { return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %w", err) } if aliasRaw == nil { return nil, nil } alias, ok := aliasRaw.(*identity.Alias) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched alias") } if clone { return alias.Clone() } return alias, nil } func (i *IdentityStore) MemDBDeleteAliasByIDInTxn(txn *memdb.Txn, aliasID string, groupAlias bool) error { if aliasID == "" { return nil } if txn == nil { return fmt.Errorf("txn is nil") } alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, groupAlias) if err != nil { return err } if alias == nil { return nil } tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } err = txn.Delete(tableName, alias) if err != nil { return fmt.Errorf("failed to delete alias from memdb: %w", err) } return nil } func (i *IdentityStore) MemDBAliases(ws memdb.WatchSet, groupAlias bool) (memdb.ResultIterator, error) { txn := i.db.Txn(false) tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } iter, err := txn.Get(tableName, "id") if err != nil { return nil, err } ws.Add(iter.WatchCh()) return iter, nil } func (i *IdentityStore) MemDBUpsertEntityInTxn(txn *memdb.Txn, entity *identity.Entity) error { if txn == nil { return fmt.Errorf("nil txn") } if entity == nil { return fmt.Errorf("entity is nil") } if entity.NamespaceID == "" { entity.NamespaceID = namespace.RootNamespaceID } entityRaw, err := txn.First(entitiesTable, "id", entity.ID) if err != nil { return fmt.Errorf("failed to lookup entity from memdb using entity id: %w", err) } if entityRaw != nil { err = txn.Delete(entitiesTable, entityRaw) if err != nil { return fmt.Errorf("failed to delete entity from memdb: %w", err) } } if err := txn.Insert(entitiesTable, entity); err != nil { return fmt.Errorf("failed to update entity into memdb: %w", err) } return nil } func (i *IdentityStore) MemDBEntityByIDInTxn(txn *memdb.Txn, entityID string, clone bool) (*identity.Entity, error) { if entityID == "" { return nil, fmt.Errorf("missing entity id") } if txn == nil { return nil, fmt.Errorf("txn is nil") } entityRaw, err := txn.First(entitiesTable, "id", entityID) if err != nil { return nil, fmt.Errorf("failed to fetch entity from memdb using entity id: %w", err) } if entityRaw == nil { return nil, nil } entity, ok := entityRaw.(*identity.Entity) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched entity") } if clone { return entity.Clone() } return entity, nil } func (i *IdentityStore) MemDBEntityByID(entityID string, clone bool) (*identity.Entity, error) { if entityID == "" { return nil, fmt.Errorf("missing entity id") } txn := i.db.Txn(false) return i.MemDBEntityByIDInTxn(txn, entityID, clone) } func (i *IdentityStore) MemDBEntityByName(ctx context.Context, entityName string, clone bool) (*identity.Entity, error) { if entityName == "" { return nil, fmt.Errorf("missing entity name") } txn := i.db.Txn(false) return i.MemDBEntityByNameInTxn(ctx, txn, entityName, clone) } func (i *IdentityStore) MemDBEntityByNameInTxn(ctx context.Context, txn *memdb.Txn, entityName string, clone bool) (*identity.Entity, error) { if entityName == "" { return nil, fmt.Errorf("missing entity name") } ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } entityRaw, err := txn.First(entitiesTable, "name", ns.ID, entityName) if err != nil { return nil, fmt.Errorf("failed to fetch entity from memdb using entity name: %w", err) } if entityRaw == nil { return nil, nil } entity, ok := entityRaw.(*identity.Entity) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched entity") } if clone { return entity.Clone() } return entity, nil } func (i *IdentityStore) MemDBLocalAliasesByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Alias, error) { if txn == nil { return nil, fmt.Errorf("nil txn") } if bucketKey == "" { return nil, fmt.Errorf("empty bucket key") } iter, err := txn.Get(entityAliasesTable, "local_bucket_key", bucketKey) if err != nil { return nil, fmt.Errorf("failed to lookup aliases using local bucket entry key hash: %w", err) } var aliases []*identity.Alias for item := iter.Next(); item != nil; item = iter.Next() { alias := item.(*identity.Alias) if alias.Local { aliases = append(aliases, alias) } } return aliases, nil } func (i *IdentityStore) MemDBEntitiesByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Entity, error) { if txn == nil { return nil, fmt.Errorf("nil txn") } if bucketKey == "" { return nil, fmt.Errorf("empty bucket key") } entitiesIter, err := txn.Get(entitiesTable, "bucket_key", bucketKey) if err != nil { return nil, fmt.Errorf("failed to lookup entities using bucket entry key hash: %w", err) } var entities []*identity.Entity for item := entitiesIter.Next(); item != nil; item = entitiesIter.Next() { entity, err := item.(*identity.Entity).Clone() if err != nil { return nil, err } entities = append(entities, entity) } return entities, nil } func (i *IdentityStore) MemDBEntityByMergedEntityID(mergedEntityID string, clone bool) (*identity.Entity, error) { if mergedEntityID == "" { return nil, fmt.Errorf("missing merged entity id") } txn := i.db.Txn(false) entityRaw, err := txn.First(entitiesTable, "merged_entity_ids", mergedEntityID) if err != nil { return nil, fmt.Errorf("failed to fetch entity from memdb using merged entity id: %w", err) } if entityRaw == nil { return nil, nil } entity, ok := entityRaw.(*identity.Entity) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched entity") } if clone { return entity.Clone() } return entity, nil } func (i *IdentityStore) MemDBEntityByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Entity, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID") } if txn == nil { return nil, fmt.Errorf("txn is nil") } alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, false) if err != nil { return nil, err } if alias == nil { return nil, nil } return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, clone) } func (i *IdentityStore) MemDBEntityByAliasID(aliasID string, clone bool) (*identity.Entity, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID") } txn := i.db.Txn(false) return i.MemDBEntityByAliasIDInTxn(txn, aliasID, clone) } func (i *IdentityStore) MemDBDeleteEntityByID(entityID string) error { if entityID == "" { return nil } txn := i.db.Txn(true) defer txn.Abort() err := i.MemDBDeleteEntityByIDInTxn(txn, entityID) if err != nil { return err } txn.Commit() return nil } func (i *IdentityStore) MemDBDeleteEntityByIDInTxn(txn *memdb.Txn, entityID string) error { if entityID == "" { return nil } if txn == nil { return fmt.Errorf("txn is nil") } entity, err := i.MemDBEntityByIDInTxn(txn, entityID, false) if err != nil { return err } if entity == nil { return nil } err = txn.Delete(entitiesTable, entity) if err != nil { return fmt.Errorf("failed to delete entity from memdb: %w", err) } return nil } func (i *IdentityStore) sanitizeAlias(ctx context.Context, alias *identity.Alias) error { var err error if alias == nil { return fmt.Errorf("alias is nil") } // Alias must always be tied to a canonical object if alias.CanonicalID == "" { return fmt.Errorf("missing canonical ID") } // Alias must have a name if alias.Name == "" { return fmt.Errorf("missing alias name %q", alias.Name) } // Alias metadata should always be map[string]string err = validateMetadata(alias.Metadata) if err != nil { return fmt.Errorf("invalid alias metadata: %w", err) } // Create an ID if there isn't one already if alias.ID == "" { alias.ID, err = uuid.GenerateUUID() if err != nil { return fmt.Errorf("failed to generate alias ID") } alias.LocalBucketKey = i.localAliasPacker.BucketKey(alias.CanonicalID) } if alias.NamespaceID == "" { ns, err := namespace.FromContext(ctx) if err != nil { return err } alias.NamespaceID = ns.ID } ns, err := namespace.FromContext(ctx) if err != nil { return err } if ns.ID != alias.NamespaceID { return errors.New("alias belongs to a different namespace") } // Set the creation and last update times if alias.CreationTime == nil { alias.CreationTime = ptypes.TimestampNow() alias.LastUpdateTime = alias.CreationTime } else { alias.LastUpdateTime = ptypes.TimestampNow() } return nil } func (i *IdentityStore) sanitizeEntity(ctx context.Context, entity *identity.Entity) error { var err error if entity == nil { return fmt.Errorf("entity is nil") } // Create an ID if there isn't one already if entity.ID == "" { entity.ID, err = uuid.GenerateUUID() if err != nil { return fmt.Errorf("failed to generate entity id") } // Set the storage bucket key in entity entity.BucketKey = i.entityPacker.BucketKey(entity.ID) } ns, err := namespace.FromContext(ctx) if err != nil { return err } if entity.NamespaceID == "" { entity.NamespaceID = ns.ID } if ns.ID != entity.NamespaceID { return errors.New("entity does not belong to this namespace") } // Create a name if there isn't one already if entity.Name == "" { entity.Name, err = i.generateName(ctx, "entity") if err != nil { return fmt.Errorf("failed to generate entity name") } } // Entity metadata should always be map[string]string err = validateMetadata(entity.Metadata) if err != nil { return fmt.Errorf("invalid entity metadata: %w", err) } // Set the creation and last update times if entity.CreationTime == nil { entity.CreationTime = ptypes.TimestampNow() entity.LastUpdateTime = entity.CreationTime } else { entity.LastUpdateTime = ptypes.TimestampNow() } // Ensure that MFASecrets is non-nil at any time. This is useful when MFA // secret generation procedures try to append MFA info to entity. if entity.MFASecrets == nil { entity.MFASecrets = make(map[string]*mfa.Secret) } return nil } func (i *IdentityStore) sanitizeAndUpsertGroup(ctx context.Context, group *identity.Group, previousGroup *identity.Group, memberGroupIDs []string) error { var err error if group == nil { return fmt.Errorf("group is nil") } // Create an ID if there isn't one already if group.ID == "" { group.ID, err = uuid.GenerateUUID() if err != nil { return fmt.Errorf("failed to generate group id") } // Set the hash value of the storage bucket key in group group.BucketKey = i.groupPacker.BucketKey(group.ID) } if group.NamespaceID == "" { ns, err := namespace.FromContext(ctx) if err != nil { return err } group.NamespaceID = ns.ID } ns, err := namespace.FromContext(ctx) if err != nil { return err } if ns.ID != group.NamespaceID { return errors.New("group does not belong to this namespace") } // Create a name if there isn't one already if group.Name == "" { group.Name, err = i.generateName(ctx, "group") if err != nil { return fmt.Errorf("failed to generate group name") } } // Entity metadata should always be map[string]string err = validateMetadata(group.Metadata) if err != nil { return fmt.Errorf("invalid group metadata: %w", err) } // Set the creation and last update times if group.CreationTime == nil { group.CreationTime = ptypes.TimestampNow() group.LastUpdateTime = group.CreationTime } else { group.LastUpdateTime = ptypes.TimestampNow() } // Remove duplicate entity IDs and check if all IDs are valid group.MemberEntityIDs = strutil.RemoveDuplicates(group.MemberEntityIDs, false) for _, entityID := range group.MemberEntityIDs { entity, err := i.MemDBEntityByID(entityID, false) if err != nil { return fmt.Errorf("failed to validate entity ID %q: %w", entityID, err) } if entity == nil { return fmt.Errorf("invalid entity ID %q", entityID) } } // Remove duplicate policies group.Policies = strutil.RemoveDuplicates(group.Policies, false) txn := i.db.Txn(true) defer txn.Abort() var currentMemberGroupIDs []string var currentMemberGroups []*identity.Group // If there are no member group IDs supplied, then it shouldn't be // processed. If an empty set of member group IDs are supplied, then it // should be processed. Hence the nil check instead of the length check. if memberGroupIDs == nil { goto ALIAS } memberGroupIDs = strutil.RemoveDuplicates(memberGroupIDs, false) // For those group member IDs that are removed from the list, remove current // group ID as their respective ParentGroupID. // Get the current MemberGroups IDs for this group currentMemberGroups, err = i.MemDBGroupsByParentGroupID(group.ID, false) if err != nil { return err } for _, currentMemberGroup := range currentMemberGroups { currentMemberGroupIDs = append(currentMemberGroupIDs, currentMemberGroup.ID) } // Update parent group IDs in the removed members for _, currentMemberGroupID := range currentMemberGroupIDs { if strutil.StrListContains(memberGroupIDs, currentMemberGroupID) { continue } currentMemberGroup, err := i.MemDBGroupByID(currentMemberGroupID, true) if err != nil { return err } if currentMemberGroup == nil { return fmt.Errorf("invalid member group ID %q", currentMemberGroupID) } // Remove group ID from the parent group IDs currentMemberGroup.ParentGroupIDs = strutil.StrListDelete(currentMemberGroup.ParentGroupIDs, group.ID) err = i.UpsertGroupInTxn(ctx, txn, currentMemberGroup, true) if err != nil { return err } } // After the group lock is held, make membership updates to all the // relevant groups for _, memberGroupID := range memberGroupIDs { memberGroup, err := i.MemDBGroupByID(memberGroupID, true) if err != nil { return err } if memberGroup == nil { return fmt.Errorf("invalid member group ID %q", memberGroupID) } // Skip if memberGroupID is already a member of group.ID if strutil.StrListContains(memberGroup.ParentGroupIDs, group.ID) { continue } // Ensure that adding memberGroupID does not lead to cyclic // relationships // Detect self loop if group.ID == memberGroupID { return fmt.Errorf("member group ID %q is same as the ID of the group", group.ID) } groupByID, err := i.MemDBGroupByID(group.ID, true) if err != nil { return err } // If group is nil, that means that a group doesn't already exist and its // okay to add any group as its member group. if groupByID != nil { // If adding the memberGroupID to groupID creates a cycle, then groupID must // be a hop in that loop. Start a DFS traversal from memberGroupID and see if // it reaches back to groupID. If it does, then it's a loop. // Created a visited set visited := make(map[string]bool) cycleDetected, err := i.detectCycleDFS(visited, groupByID.ID, memberGroupID) if err != nil { return fmt.Errorf("failed to perform cyclic relationship detection for member group ID %q", memberGroupID) } if cycleDetected { return fmt.Errorf("%s %q", errCycleDetectedPrefix, memberGroupID) } } memberGroup.ParentGroupIDs = append(memberGroup.ParentGroupIDs, group.ID) // This technically is not upsert. It is only update, only the method // name is upsert here. err = i.UpsertGroupInTxn(ctx, txn, memberGroup, true) if err != nil { // Ideally we would want to revert the whole operation in case of // errors while persisting in member groups. But there is no // storage transaction support yet. When we do have it, this will need // an update. return err } } ALIAS: // Sanitize the group alias if group.Alias != nil { group.Alias.CanonicalID = group.ID err = i.sanitizeAlias(ctx, group.Alias) if err != nil { return err } } // If previousGroup is not nil, we are moving the alias from the previous // group to the new one. As a result we need to upsert both in the context // of this same transaction. if previousGroup != nil { err = i.UpsertGroupInTxn(ctx, txn, previousGroup, true) if err != nil { return err } } err = i.UpsertGroupInTxn(ctx, txn, group, true) if err != nil { return err } txn.Commit() return nil } func (i *IdentityStore) deleteAliasesInEntityInTxn(txn *memdb.Txn, entity *identity.Entity, aliases []*identity.Alias) error { if entity == nil { return fmt.Errorf("entity is nil") } if txn == nil { return fmt.Errorf("txn is nil") } var remainList []*identity.Alias var removeList []*identity.Alias for _, item := range entity.Aliases { remove := false for _, alias := range aliases { if alias.ID == item.ID { remove = true } } if remove { removeList = append(removeList, item) } else { remainList = append(remainList, item) } } // Remove identity indices from aliases table for those that needs to // be removed for _, alias := range removeList { err := i.MemDBDeleteAliasByIDInTxn(txn, alias.ID, false) if err != nil { return err } } // Update the entity with remaining items entity.Aliases = remainList return nil } // validateMeta validates a set of key/value pairs from the agent config func validateMetadata(meta map[string]string) error { if len(meta) > metaMaxKeyPairs { return fmt.Errorf("metadata cannot contain more than %d key/value pairs", metaMaxKeyPairs) } for key, value := range meta { if err := validateMetaPair(key, value); err != nil { return fmt.Errorf("failed to load metadata pair (%q, %q): %w", key, value, err) } } return nil } // validateMetaPair checks that the given key/value pair is in a valid format func validateMetaPair(key, value string) error { if key == "" { return fmt.Errorf("key cannot be blank") } if !metaKeyFormatRegEx(key) { return fmt.Errorf("key contains invalid characters") } if len(key) > metaKeyMaxLength { return fmt.Errorf("key is too long (limit: %d characters)", metaKeyMaxLength) } if strings.HasPrefix(key, metaKeyReservedPrefix) { return fmt.Errorf("key prefix %q is reserved for internal use", metaKeyReservedPrefix) } if len(value) > metaValueMaxLength { return fmt.Errorf("value is too long (limit: %d characters)", metaValueMaxLength) } return nil } func (i *IdentityStore) MemDBGroupByNameInTxn(ctx context.Context, txn *memdb.Txn, groupName string, clone bool) (*identity.Group, error) { if groupName == "" { return nil, fmt.Errorf("missing group name") } if txn == nil { return nil, fmt.Errorf("txn is nil") } ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } groupRaw, err := txn.First(groupsTable, "name", ns.ID, groupName) if err != nil { return nil, fmt.Errorf("failed to fetch group from memdb using group name: %w", err) } if groupRaw == nil { return nil, nil } group, ok := groupRaw.(*identity.Group) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched group") } if clone { return group.Clone() } return group, nil } func (i *IdentityStore) MemDBGroupByName(ctx context.Context, groupName string, clone bool) (*identity.Group, error) { if groupName == "" { return nil, fmt.Errorf("missing group name") } txn := i.db.Txn(false) return i.MemDBGroupByNameInTxn(ctx, txn, groupName, clone) } func (i *IdentityStore) UpsertGroup(ctx context.Context, group *identity.Group, persist bool) error { defer metrics.MeasureSince([]string{"identity", "upsert_group"}, time.Now()) txn := i.db.Txn(true) defer txn.Abort() err := i.UpsertGroupInTxn(ctx, txn, group, true) if err != nil { return err } txn.Commit() return nil } func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, group *identity.Group, persist bool) error { defer metrics.MeasureSince([]string{"identity", "upsert_group_txn"}, time.Now()) var err error if txn == nil { return fmt.Errorf("txn is nil") } if group == nil { return fmt.Errorf("group is nil") } // Increment the modify index of the group group.ModifyIndex++ // Clear the old alias from memdb groupClone, err := i.MemDBGroupByID(group.ID, true) if err != nil { return err } if groupClone != nil && groupClone.Alias != nil { err = i.MemDBDeleteAliasByIDInTxn(txn, groupClone.Alias.ID, true) if err != nil { return err } } // Add the new alias to memdb if group.Alias != nil { err = i.MemDBUpsertAliasInTxn(txn, group.Alias, true) if err != nil { return err } } // Insert or update group in MemDB using the transaction created above err = i.MemDBUpsertGroupInTxn(txn, group) if err != nil { return err } if persist { groupAsAny, err := ptypes.MarshalAny(group) if err != nil { return err } item := &storagepacker.Item{ ID: group.ID, Message: groupAsAny, } sent, err := i.groupUpdater.SendGroupUpdate(ctx, group) if err != nil { return err } if !sent { if err := i.groupPacker.PutItem(ctx, item); err != nil { return err } } } return nil } func (i *IdentityStore) MemDBUpsertGroupInTxn(txn *memdb.Txn, group *identity.Group) error { if txn == nil { return fmt.Errorf("nil txn") } if group == nil { return fmt.Errorf("group is nil") } if group.NamespaceID == "" { group.NamespaceID = namespace.RootNamespaceID } groupRaw, err := txn.First(groupsTable, "id", group.ID) if err != nil { return fmt.Errorf("failed to lookup group from memdb using group id: %w", err) } if groupRaw != nil { err = txn.Delete(groupsTable, groupRaw) if err != nil { return fmt.Errorf("failed to delete group from memdb: %w", err) } } if err := txn.Insert(groupsTable, group); err != nil { return fmt.Errorf("failed to update group into memdb: %w", err) } return nil } func (i *IdentityStore) MemDBDeleteGroupByIDInTxn(txn *memdb.Txn, groupID string) error { if groupID == "" { return nil } if txn == nil { return fmt.Errorf("txn is nil") } group, err := i.MemDBGroupByIDInTxn(txn, groupID, false) if err != nil { return err } if group == nil { return nil } err = txn.Delete("groups", group) if err != nil { return fmt.Errorf("failed to delete group from memdb: %w", err) } return nil } func (i *IdentityStore) MemDBGroupByIDInTxn(txn *memdb.Txn, groupID string, clone bool) (*identity.Group, error) { if groupID == "" { return nil, fmt.Errorf("missing group ID") } if txn == nil { return nil, fmt.Errorf("txn is nil") } groupRaw, err := txn.First(groupsTable, "id", groupID) if err != nil { return nil, fmt.Errorf("failed to fetch group from memdb using group ID: %w", err) } if groupRaw == nil { return nil, nil } group, ok := groupRaw.(*identity.Group) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched group") } if clone { return group.Clone() } return group, nil } func (i *IdentityStore) MemDBGroupByID(groupID string, clone bool) (*identity.Group, error) { if groupID == "" { return nil, fmt.Errorf("missing group ID") } txn := i.db.Txn(false) return i.MemDBGroupByIDInTxn(txn, groupID, clone) } func (i *IdentityStore) MemDBGroupsByParentGroupIDInTxn(txn *memdb.Txn, memberGroupID string, clone bool) ([]*identity.Group, error) { if memberGroupID == "" { return nil, fmt.Errorf("missing member group ID") } groupsIter, err := txn.Get(groupsTable, "parent_group_ids", memberGroupID) if err != nil { return nil, fmt.Errorf("failed to lookup groups using member group ID: %w", err) } var groups []*identity.Group for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { entry := group.(*identity.Group) if clone { entry, err = entry.Clone() if err != nil { return nil, err } } groups = append(groups, entry) } return groups, nil } func (i *IdentityStore) MemDBGroupsByParentGroupID(memberGroupID string, clone bool) ([]*identity.Group, error) { if memberGroupID == "" { return nil, fmt.Errorf("missing member group ID") } txn := i.db.Txn(false) return i.MemDBGroupsByParentGroupIDInTxn(txn, memberGroupID, clone) } func (i *IdentityStore) MemDBGroupsByMemberEntityID(entityID string, clone bool, externalOnly bool) ([]*identity.Group, error) { txn := i.db.Txn(false) defer txn.Abort() return i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, clone, externalOnly) } func (i *IdentityStore) MemDBGroupsByMemberEntityIDInTxn(txn *memdb.Txn, entityID string, clone bool, externalOnly bool) ([]*identity.Group, error) { if entityID == "" { return nil, fmt.Errorf("missing entity ID") } groupsIter, err := txn.Get(groupsTable, "member_entity_ids", entityID) if err != nil { return nil, fmt.Errorf("failed to lookup groups using entity ID: %w", err) } var groups []*identity.Group for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { entry := group.(*identity.Group) if externalOnly && entry.Type == groupTypeInternal { continue } if clone { entry, err = entry.Clone() if err != nil { return nil, err } } groups = append(groups, entry) } return groups, nil } func (i *IdentityStore) groupPoliciesByEntityID(entityID string) (map[string][]string, error) { if entityID == "" { return nil, fmt.Errorf("empty entity ID") } groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false) if err != nil { return nil, err } visited := make(map[string]bool) policies := make(map[string][]string) for _, group := range groups { err := i.collectPoliciesReverseDFS(group, visited, policies) if err != nil { return nil, err } } return policies, nil } func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) { if entityID == "" { return nil, nil, fmt.Errorf("empty entity ID") } groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false) if err != nil { return nil, nil, err } visited := make(map[string]bool) var tGroups []*identity.Group for _, group := range groups { gGroups, err := i.collectGroupsReverseDFS(group, visited, nil) if err != nil { return nil, nil, err } tGroups = append(tGroups, gGroups...) } // Remove duplicates groupMap := make(map[string]*identity.Group) for _, group := range tGroups { groupMap[group.ID] = group } tGroups = make([]*identity.Group, 0, len(groupMap)) for _, group := range groupMap { tGroups = append(tGroups, group) } diff := diffGroups(groups, tGroups) // For sanity // There should not be any group that gets deleted if len(diff.Deleted) != 0 { return nil, nil, fmt.Errorf("failed to diff group memberships") } return diff.Unmodified, diff.New, nil } func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) { if group == nil { return nil, fmt.Errorf("nil group") } // If traversal for a groupID is performed before, skip it if visited[group.ID] { return groups, nil } visited[group.ID] = true groups = append(groups, group) // Traverse all the parent groups for _, parentGroupID := range group.ParentGroupIDs { parentGroup, err := i.MemDBGroupByID(parentGroupID, false) if err != nil { return nil, err } if parentGroup == nil { continue } groups, err = i.collectGroupsReverseDFS(parentGroup, visited, groups) if err != nil { return nil, fmt.Errorf("failed to collect group at parent group ID %q", parentGroup.ID) } } return groups, nil } func (i *IdentityStore) collectPoliciesReverseDFS(group *identity.Group, visited map[string]bool, policies map[string][]string) error { if group == nil { return fmt.Errorf("nil group") } // If traversal for a groupID is performed before, skip it if visited[group.ID] { return nil } visited[group.ID] = true policies[group.NamespaceID] = append(policies[group.NamespaceID], group.Policies...) // Traverse all the parent groups for _, parentGroupID := range group.ParentGroupIDs { parentGroup, err := i.MemDBGroupByID(parentGroupID, false) if err != nil { return err } if parentGroup == nil { continue } err = i.collectPoliciesReverseDFS(parentGroup, visited, policies) if err != nil { return fmt.Errorf("failed to collect policies at parent group ID %q", parentGroup.ID) } } return nil } func (i *IdentityStore) detectCycleDFS(visited map[string]bool, startingGroupID, groupID string) (bool, error) { // If the traversal reaches the startingGroupID, a loop is detected if startingGroupID == groupID { return true, nil } // If traversal for a groupID is performed before, skip it if visited[groupID] { return false, nil } visited[groupID] = true group, err := i.MemDBGroupByID(groupID, true) if err != nil { return false, err } if group == nil { return false, nil } // Fetch all groups in which groupID is present as a ParentGroupID. In // other words, find all the subgroups of groupID. memberGroups, err := i.MemDBGroupsByParentGroupID(groupID, false) if err != nil { return false, err } // DFS traverse the member groups for _, memberGroup := range memberGroups { cycleDetected, err := i.detectCycleDFS(visited, startingGroupID, memberGroup.ID) if err != nil { return false, fmt.Errorf("failed to perform cycle detection at member group ID %q", memberGroup.ID) } if cycleDetected { return true, nil } } return false, nil } func (i *IdentityStore) memberGroupIDsByID(groupID string) ([]string, error) { var memberGroupIDs []string memberGroups, err := i.MemDBGroupsByParentGroupID(groupID, false) if err != nil { return nil, err } for _, memberGroup := range memberGroups { memberGroupIDs = append(memberGroupIDs, memberGroup.ID) } return memberGroupIDs, nil } func (i *IdentityStore) generateName(ctx context.Context, entryType string) (string, error) { var name string OUTER: for { randBytes, err := uuid.GenerateRandomBytes(4) if err != nil { return "", err } name = fmt.Sprintf("%s_%s", entryType, fmt.Sprintf("%08x", randBytes[0:4])) switch entryType { case "entity": entity, err := i.MemDBEntityByName(ctx, name, false) if err != nil { return "", err } if entity == nil { break OUTER } case "group": group, err := i.MemDBGroupByName(ctx, name, false) if err != nil { return "", err } if group == nil { break OUTER } default: return "", fmt.Errorf("unrecognized type %q", entryType) } } return name, nil } func (i *IdentityStore) MemDBGroupsByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Group, error) { if txn == nil { return nil, fmt.Errorf("nil txn") } if bucketKey == "" { return nil, fmt.Errorf("empty bucket key") } groupsIter, err := txn.Get(groupsTable, "bucket_key", bucketKey) if err != nil { return nil, fmt.Errorf("failed to lookup groups using bucket entry key hash: %w", err) } var groups []*identity.Group for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { groups = append(groups, group.(*identity.Group)) } return groups, nil } func (i *IdentityStore) MemDBGroupByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Group, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID") } if txn == nil { return nil, fmt.Errorf("txn is nil") } alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, true) if err != nil { return nil, err } if alias == nil { return nil, nil } return i.MemDBGroupByIDInTxn(txn, alias.CanonicalID, clone) } func (i *IdentityStore) MemDBGroupByAliasID(aliasID string, clone bool) (*identity.Group, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID") } txn := i.db.Txn(false) return i.MemDBGroupByAliasIDInTxn(txn, aliasID, clone) } func (i *IdentityStore) refreshExternalGroupMembershipsByEntityID(ctx context.Context, entityID string, groupAliases []*logical.Alias, mountAccessor string) ([]*logical.Alias, error) { defer metrics.MeasureSince([]string{"identity", "refresh_external_groups"}, time.Now()) if entityID == "" { return nil, fmt.Errorf("empty entity ID") } refreshFunc := func(dryRun bool) (bool, []*logical.Alias, error) { if !dryRun { i.groupLock.Lock() defer i.groupLock.Unlock() } txn := i.db.Txn(!dryRun) defer txn.Abort() oldGroups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, true, true) if err != nil { return false, nil, err } var newGroups []*identity.Group var validAliases []*logical.Alias for _, alias := range groupAliases { aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, true, true) if err != nil { return false, nil, err } if aliasByFactors == nil { continue } mappingGroup, err := i.MemDBGroupByAliasIDInTxn(txn, aliasByFactors.ID, true) if err != nil { return false, nil, err } if mappingGroup == nil { return false, nil, fmt.Errorf("group unavailable for a valid alias ID %q", aliasByFactors.ID) } newGroups = append(newGroups, mappingGroup) validAliases = append(validAliases, alias) } diff := diffGroups(oldGroups, newGroups) // Add the entity ID to all the new groups for _, group := range diff.New { if group.Type != groupTypeExternal { continue } // We need to update a group, if we are in a dry run we should // report back that a change needs to take place. if dryRun { return true, nil, nil } i.logger.Debug("adding member entity ID to external group", "member_entity_id", entityID, "group_id", group.ID) group.MemberEntityIDs = append(group.MemberEntityIDs, entityID) err = i.UpsertGroupInTxn(ctx, txn, group, true) if err != nil { return false, nil, err } } // Remove the entity ID from all the deleted groups for _, group := range diff.Deleted { if group.Type != groupTypeExternal { continue } // If the external group is from a different mount, don't remove the // entity ID from it. if mountAccessor != "" && group.Alias != nil && group.Alias.MountAccessor != mountAccessor { continue } // We need to update a group, if we are in a dry run we should // report back that a change needs to take place. if dryRun { return true, nil, nil } i.logger.Debug("removing member entity ID from external group", "member_entity_id", entityID, "group_id", group.ID) group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, entityID) err = i.UpsertGroupInTxn(ctx, txn, group, true) if err != nil { return false, nil, err } } txn.Commit() return false, validAliases, nil } // dryRun needsUpdate, validAliases, err := refreshFunc(true) if err != nil { return nil, err } if needsUpdate || len(groupAliases) > 0 { i.logger.Debug("refreshing external group memberships", "entity_id", entityID, "group_aliases", groupAliases) } if !needsUpdate { return validAliases, nil } // Run the update _, validAliases, err = refreshFunc(false) if err != nil { return nil, err } return validAliases, nil } // diffGroups is used to diff two sets of groups func diffGroups(old, new []*identity.Group) *groupDiff { diff := &groupDiff{} existing := make(map[string]*identity.Group) for _, group := range old { existing[group.ID] = group } for _, group := range new { // Check if the entry in new is present in the old _, ok := existing[group.ID] // If its not present, then its a new entry if !ok { diff.New = append(diff.New, group) continue } // If its present, it means that its unmodified diff.Unmodified = append(diff.Unmodified, group) // By deleting the unmodified from the old set, we could determine the // ones that are stale by looking at the remaining ones. delete(existing, group.ID) } // Any remaining entries must have been deleted for _, me := range existing { diff.Deleted = append(diff.Deleted, me) } return diff } func (i *IdentityStore) handleAliasListCommon(ctx context.Context, groupAlias bool) (*logical.Response, error) { ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } ws := memdb.NewWatchSet() txn := i.db.Txn(false) iter, err := txn.Get(tableName, "namespace_id", ns.ID) if err != nil { return nil, fmt.Errorf("failed to fetch iterator for aliases in memdb: %w", err) } ws.Add(iter.WatchCh()) var aliasIDs []string aliasInfo := map[string]interface{}{} type mountInfo struct { MountType string MountPath string } mountAccessorMap := map[string]mountInfo{} for { raw := iter.Next() if raw == nil { break } alias := raw.(*identity.Alias) aliasIDs = append(aliasIDs, alias.ID) aliasInfoEntry := map[string]interface{}{ "name": alias.Name, "canonical_id": alias.CanonicalID, "mount_accessor": alias.MountAccessor, "custom_metadata": alias.CustomMetadata, "local": alias.Local, } mi, ok := mountAccessorMap[alias.MountAccessor] if ok { aliasInfoEntry["mount_type"] = mi.MountType aliasInfoEntry["mount_path"] = mi.MountPath } else { mi = mountInfo{} if mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor); mountValidationResp != nil { mi.MountType = mountValidationResp.MountType mi.MountPath = mountValidationResp.MountPath aliasInfoEntry["mount_type"] = mi.MountType aliasInfoEntry["mount_path"] = mi.MountPath } mountAccessorMap[alias.MountAccessor] = mi } aliasInfo[alias.ID] = aliasInfoEntry } return logical.ListResponseWithInfo(aliasIDs, aliasInfo), nil } func (i *IdentityStore) countEntities() (int, error) { txn := i.db.Txn(false) iter, err := txn.Get(entitiesTable, "id") if err != nil { return -1, err } count := 0 val := iter.Next() for val != nil { count++ val = iter.Next() } return count, nil } // Sum up the number of entities belonging to each namespace (keyed by ID) func (i *IdentityStore) countEntitiesByNamespace(ctx context.Context) (map[string]int, error) { txn := i.db.Txn(false) iter, err := txn.Get(entitiesTable, "id") if err != nil { return nil, err } byNamespace := make(map[string]int) val := iter.Next() for val != nil { // Check if runtime exceeded. select { case <-ctx.Done(): return byNamespace, errors.New("context cancelled") default: break } // Count in the namespace attached to the entity. entity := val.(*identity.Entity) byNamespace[entity.NamespaceID] = byNamespace[entity.NamespaceID] + 1 val = iter.Next() } return byNamespace, nil } // Sum up the number of entities belonging to each mount point (keyed by accessor) func (i *IdentityStore) countEntitiesByMountAccessor(ctx context.Context) (map[string]int, error) { txn := i.db.Txn(false) iter, err := txn.Get(entitiesTable, "id") if err != nil { return nil, err } byMountAccessor := make(map[string]int) val := iter.Next() for val != nil { // Check if runtime exceeded. select { case <-ctx.Done(): return byMountAccessor, errors.New("context cancelled") default: break } // Count each alias separately; will translate to mount point and type // in the caller. entity := val.(*identity.Entity) for _, alias := range entity.Aliases { byMountAccessor[alias.MountAccessor] = byMountAccessor[alias.MountAccessor] + 1 } val = iter.Next() } return byMountAccessor, nil }