2594 lines
65 KiB
Go
2594 lines
65 KiB
Go
package vault
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
"github.com/golang/protobuf/ptypes"
|
|
"github.com/hashicorp/errwrap"
|
|
memdb "github.com/hashicorp/go-memdb"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
uuid "github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/helper/identity"
|
|
"github.com/hashicorp/vault/helper/identity/mfa"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/helper/storagepacker"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
// errDuplicateIdentityName is returned while loading or upserting identity
// objects when two of them collide on name and lower-cased (case-insensitive)
// names are in effect.
var errDuplicateIdentityName = errors.New("duplicate identity name")

// tmpSuffix is appended to an entity ID to form the storage item ID under
// which a temporarily cached entity is kept in the local alias packer.
var tmpSuffix = ".tmp"
|
|
|
|
// SetLoadCaseSensitiveIdentityStore records whether identity data should be
// loaded with case-sensitive names. When caseSensitive is true,
// loadIdentityStoreArtifacts skips its initial lower-cased loading attempt and
// switches the identity store to case-sensitive names immediately.
func (c *Core) SetLoadCaseSensitiveIdentityStore(caseSensitive bool) {
	c.loadCaseSensitiveIdentityStore = caseSensitive
}
|
|
|
|
func (c *Core) loadIdentityStoreArtifacts(ctx context.Context) error {
|
|
if c.identityStore == nil {
|
|
c.logger.Warn("identity store is not setup, skipping loading")
|
|
return nil
|
|
}
|
|
|
|
loadFunc := func(context.Context) error {
|
|
if err := c.identityStore.loadEntities(ctx); err != nil {
|
|
return err
|
|
}
|
|
if err := c.identityStore.loadGroups(ctx); err != nil {
|
|
return err
|
|
}
|
|
if err := c.identityStore.loadOIDCClients(ctx); err != nil {
|
|
return err
|
|
}
|
|
if err := c.identityStore.loadCachedEntitiesOfLocalAliases(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
if !c.loadCaseSensitiveIdentityStore {
|
|
// Load everything when memdb is set to operate on lower cased names
|
|
err := loadFunc(ctx)
|
|
switch {
|
|
case err == nil:
|
|
// If it succeeds, all is well
|
|
return nil
|
|
case !errwrap.Contains(err, errDuplicateIdentityName.Error()):
|
|
return err
|
|
}
|
|
}
|
|
|
|
c.identityStore.logger.Warn("enabling case sensitive identity names")
|
|
|
|
// Set identity store to operate on case sensitive identity names
|
|
c.identityStore.disableLowerCasedNames = true
|
|
|
|
// Swap the memdb instance by the one which operates on case sensitive
|
|
// names, hence obviating the need to unload anything that's already
|
|
// loaded.
|
|
if err := c.identityStore.resetDB(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Attempt to load identity artifacts once more after memdb is reset to
|
|
// accept case sensitive names
|
|
return loadFunc(ctx)
|
|
}
|
|
|
|
func (i *IdentityStore) sanitizeName(name string) string {
|
|
if i.disableLowerCasedNames {
|
|
return name
|
|
}
|
|
return strings.ToLower(name)
|
|
}
|
|
|
|
func (i *IdentityStore) loadOIDCClients(ctx context.Context) error {
|
|
i.logger.Debug("identity loading OIDC clients")
|
|
|
|
clients, err := i.view.List(ctx, clientPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn := i.db.Txn(true)
|
|
defer txn.Abort()
|
|
for _, name := range clients {
|
|
entry, err := i.view.Get(ctx, clientPath+name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if entry == nil {
|
|
continue
|
|
}
|
|
|
|
var client client
|
|
if err := entry.DecodeJSON(&client); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := i.memDBUpsertClientInTxn(txn, &client); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
txn.Commit()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) loadGroups(ctx context.Context) error {
|
|
i.logger.Debug("identity loading groups")
|
|
existing, err := i.groupPacker.View().List(ctx, groupBucketsPrefix)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to scan for groups: %w", err)
|
|
}
|
|
i.logger.Debug("groups collected", "num_existing", len(existing))
|
|
|
|
for _, key := range existing {
|
|
bucket, err := i.groupPacker.GetBucket(ctx, groupBucketsPrefix+key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if bucket == nil {
|
|
continue
|
|
}
|
|
|
|
for _, item := range bucket.Items {
|
|
group, err := i.parseGroupFromBucketItem(item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if group == nil {
|
|
continue
|
|
}
|
|
|
|
ns, err := i.namespacer.NamespaceByID(ctx, group.NamespaceID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ns == nil {
|
|
// Remove dangling groups
|
|
if !(i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby) {
|
|
// Group's namespace doesn't exist anymore but the group
|
|
// from the namespace still exists.
|
|
i.logger.Warn("deleting group and its any existing aliases", "name", group.Name, "namespace_id", group.NamespaceID)
|
|
err = i.groupPacker.DeleteItem(ctx, group.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
nsCtx := namespace.ContextWithNamespace(ctx, ns)
|
|
|
|
// Ensure that there are no groups with duplicate names
|
|
groupByName, err := i.MemDBGroupByName(nsCtx, group.Name, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if groupByName != nil {
|
|
i.logger.Warn(errDuplicateIdentityName.Error(), "group_name", group.Name, "conflicting_group_name", groupByName.Name, "action", "merge the contents of duplicated groups into one and delete the other")
|
|
if !i.disableLowerCasedNames {
|
|
return errDuplicateIdentityName
|
|
}
|
|
}
|
|
|
|
if i.logger.IsDebug() {
|
|
i.logger.Debug("loading group", "name", group.Name, "id", group.ID)
|
|
}
|
|
|
|
txn := i.db.Txn(true)
|
|
|
|
// Before pull#5786, entity memberships in groups were not getting
|
|
// updated when respective entities were deleted. This is here to
|
|
// check that the entity IDs in the group are indeed valid, and if
|
|
// not remove them.
|
|
persist := false
|
|
for _, memberEntityID := range group.MemberEntityIDs {
|
|
entity, err := i.MemDBEntityByID(memberEntityID, false)
|
|
if err != nil {
|
|
txn.Abort()
|
|
return err
|
|
}
|
|
if entity == nil {
|
|
persist = true
|
|
group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, memberEntityID)
|
|
}
|
|
}
|
|
|
|
err = i.UpsertGroupInTxn(ctx, txn, group, persist)
|
|
if err != nil {
|
|
txn.Abort()
|
|
return fmt.Errorf("failed to update group in memdb: %w", err)
|
|
}
|
|
|
|
txn.Commit()
|
|
}
|
|
}
|
|
|
|
if i.logger.IsInfo() {
|
|
i.logger.Info("groups restored")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) error {
|
|
// If we are performance secondary, load from temporary location those
|
|
// entities that were created by the secondary via RPCs to the primary, and
|
|
// also happen to have not yet been shipped to the secondary through
|
|
// performance replication.
|
|
if !i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) {
|
|
return nil
|
|
}
|
|
|
|
i.logger.Debug("loading cached entities of local aliases")
|
|
existing, err := i.localAliasPacker.View().List(ctx, localAliasesBucketsPrefix)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to scan for cached entities of local alias: %w", err)
|
|
}
|
|
|
|
i.logger.Debug("cached entities of local alias entries", "num_buckets", len(existing))
|
|
|
|
// Make the channels used for the worker pool
|
|
broker := make(chan string)
|
|
quit := make(chan bool)
|
|
|
|
// Buffer these channels to prevent deadlocks
|
|
errs := make(chan error, len(existing))
|
|
result := make(chan *storagepacker.Bucket, len(existing))
|
|
|
|
// Use a wait group
|
|
wg := &sync.WaitGroup{}
|
|
|
|
// Create 64 workers to distribute work to
|
|
for j := 0; j < consts.ExpirationRestoreWorkerCount; j++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case key, ok := <-broker:
|
|
// broker has been closed, we are done
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
bucket, err := i.localAliasPacker.GetBucket(ctx, localAliasesBucketsPrefix+key)
|
|
if err != nil {
|
|
errs <- err
|
|
continue
|
|
}
|
|
|
|
// Write results out to the result channel
|
|
result <- bucket
|
|
|
|
// quit early
|
|
case <-quit:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Distribute the collected keys to the workers in a go routine
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for j, key := range existing {
|
|
if j%500 == 0 {
|
|
i.logger.Debug("cached entities of local aliases loading", "progress", j)
|
|
}
|
|
|
|
select {
|
|
case <-quit:
|
|
return
|
|
|
|
default:
|
|
broker <- key
|
|
}
|
|
}
|
|
|
|
// Close the broker, causing worker routines to exit
|
|
close(broker)
|
|
}()
|
|
|
|
// Restore each key by pulling from the result chan
|
|
for j := 0; j < len(existing); j++ {
|
|
select {
|
|
case err := <-errs:
|
|
// Close all go routines
|
|
close(quit)
|
|
|
|
return err
|
|
|
|
case bucket := <-result:
|
|
// If there is no entry, nothing to restore
|
|
if bucket == nil {
|
|
continue
|
|
}
|
|
|
|
for _, item := range bucket.Items {
|
|
if !strings.HasSuffix(item.ID, tmpSuffix) {
|
|
continue
|
|
}
|
|
entity, err := i.parseCachedEntity(item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ns, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
nsCtx := namespace.ContextWithNamespace(ctx, ns)
|
|
|
|
err = i.upsertEntity(nsCtx, entity, nil, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update entity in MemDB: %w", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Let all go routines finish
|
|
wg.Wait()
|
|
|
|
if i.logger.IsInfo() {
|
|
i.logger.Info("cached entities of local aliases restored")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) loadEntities(ctx context.Context) error {
|
|
// Accumulate existing entities
|
|
i.logger.Debug("loading entities")
|
|
existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to scan for entities: %w", err)
|
|
}
|
|
i.logger.Debug("entities collected", "num_existing", len(existing))
|
|
|
|
duplicatedAccessors := make(map[string]struct{})
|
|
// Make the channels used for the worker pool
|
|
broker := make(chan string)
|
|
quit := make(chan bool)
|
|
|
|
// Buffer these channels to prevent deadlocks
|
|
errs := make(chan error, len(existing))
|
|
result := make(chan *storagepacker.Bucket, len(existing))
|
|
|
|
// Use a wait group
|
|
wg := &sync.WaitGroup{}
|
|
|
|
// Create 64 workers to distribute work to
|
|
for j := 0; j < consts.ExpirationRestoreWorkerCount; j++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case key, ok := <-broker:
|
|
// broker has been closed, we are done
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
bucket, err := i.entityPacker.GetBucket(ctx, storagepacker.StoragePackerBucketsPrefix+key)
|
|
if err != nil {
|
|
errs <- err
|
|
continue
|
|
}
|
|
|
|
// Write results out to the result channel
|
|
result <- bucket
|
|
|
|
// quit early
|
|
case <-quit:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Distribute the collected keys to the workers in a go routine
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for j, key := range existing {
|
|
if j%500 == 0 {
|
|
i.logger.Debug("entities loading", "progress", j)
|
|
}
|
|
|
|
select {
|
|
case <-quit:
|
|
return
|
|
|
|
default:
|
|
broker <- key
|
|
}
|
|
}
|
|
|
|
// Close the broker, causing worker routines to exit
|
|
close(broker)
|
|
}()
|
|
|
|
// Restore each key by pulling from the result chan
|
|
for j := 0; j < len(existing); j++ {
|
|
select {
|
|
case err := <-errs:
|
|
// Close all go routines
|
|
close(quit)
|
|
|
|
return err
|
|
|
|
case bucket := <-result:
|
|
// If there is no entry, nothing to restore
|
|
if bucket == nil {
|
|
continue
|
|
}
|
|
|
|
for _, item := range bucket.Items {
|
|
entity, err := i.parseEntityFromBucketItem(ctx, item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if entity == nil {
|
|
continue
|
|
}
|
|
|
|
ns, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ns == nil {
|
|
// Remove dangling entities
|
|
if !(i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby) {
|
|
// Entity's namespace doesn't exist anymore but the
|
|
// entity from the namespace still exists.
|
|
i.logger.Warn("deleting entity and its any existing aliases", "name", entity.Name, "namespace_id", entity.NamespaceID)
|
|
err = i.entityPacker.DeleteItem(ctx, entity.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
nsCtx := namespace.ContextWithNamespace(ctx, ns)
|
|
|
|
// Ensure that there are no entities with duplicate names
|
|
entityByName, err := i.MemDBEntityByName(nsCtx, entity.Name, false)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
if entityByName != nil {
|
|
i.logger.Warn(errDuplicateIdentityName.Error(), "entity_name", entity.Name, "conflicting_entity_name", entityByName.Name, "action", "merge the duplicate entities into one")
|
|
if !i.disableLowerCasedNames {
|
|
return errDuplicateIdentityName
|
|
}
|
|
}
|
|
|
|
mountAccessors := getAccessorsOnDuplicateAliases(entity.Aliases)
|
|
|
|
for _, accessor := range mountAccessors {
|
|
if _, ok := duplicatedAccessors[accessor]; !ok {
|
|
duplicatedAccessors[accessor] = struct{}{}
|
|
}
|
|
}
|
|
|
|
localAliases, err := i.parseLocalAliases(entity.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load local aliases from storage: %v", err)
|
|
}
|
|
if localAliases != nil {
|
|
for _, alias := range localAliases.Aliases {
|
|
entity.UpsertAlias(alias)
|
|
}
|
|
}
|
|
|
|
// Only update MemDB and don't hit the storage again
|
|
err = i.upsertEntity(nsCtx, entity, nil, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update entity in MemDB: %w", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Let all go routines finish
|
|
wg.Wait()
|
|
|
|
// Flatten the map into a list of keys, in order to log them
|
|
duplicatedAccessorsList := make([]string, len(duplicatedAccessors))
|
|
accessorCounter := 0
|
|
for accessor := range duplicatedAccessors {
|
|
duplicatedAccessorsList[accessorCounter] = accessor
|
|
accessorCounter++
|
|
}
|
|
|
|
if len(duplicatedAccessorsList) > 0 {
|
|
i.logger.Warn("One or more entities have multiple aliases on the same mount(s), remove duplicates to avoid ACL templating issues", "mount_accessors", duplicatedAccessorsList)
|
|
}
|
|
|
|
if i.logger.IsInfo() {
|
|
i.logger.Info("entities restored")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getAccessorsOnDuplicateAliases returns a list of accessors by checking aliases in
|
|
// the passed in list which belong to the same accessor(s)
|
|
func getAccessorsOnDuplicateAliases(aliases []*identity.Alias) []string {
|
|
accessorCounts := make(map[string]int)
|
|
var mountAccessors []string
|
|
|
|
for _, alias := range aliases {
|
|
accessorCounts[alias.MountAccessor] += 1
|
|
}
|
|
|
|
for accessor, accessorCount := range accessorCounts {
|
|
if accessorCount > 1 {
|
|
mountAccessors = append(mountAccessors, accessor)
|
|
}
|
|
}
|
|
|
|
return mountAccessors
|
|
}
|
|
|
|
// upsertEntityInTxn either creates or updates an existing entity. The
// operations will be updated in both MemDB and storage. If 'persist' is set to
// false, then storage will not be updated. When an alias is transferred from
// one entity to another, both the source and destination entities should get
// updated, in which case, callers should send in both entity and
// previousEntity.
//
// All MemDB writes go through the caller-supplied txn. This function neither
// commits nor aborts it. Empty NamespaceIDs on entity and on each of its
// aliases are set to the root namespace in place.
//
// If one of entity's aliases is found in MemDB attached to another entity, the
// two are merged via mergeEntity. This does not apply when the other entity is
// previousEntity and it still holds that alias. After a merge the function
// returns immediately: remaining aliases, previousEntity and entity itself
// are not processed further here.
//
// Returns errDuplicateIdentityName when entity carries two aliases with the
// same sanitized name on the same mount and lower-cased names are in use.
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
	defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now())
	var err error

	if txn == nil {
		return errors.New("txn is nil")
	}

	if entity == nil {
		return errors.New("entity is nil")
	}

	if entity.NamespaceID == "" {
		entity.NamespaceID = namespace.RootNamespaceID
	}

	if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID {
		return errors.New("entity and previous entity are not in the same namespace")
	}

	// aliasFactors[k] holds "<sanitized name><mount accessor>" for the k-th
	// alias once it has been upserted. It is used to detect two aliases of
	// this entity that collide after name sanitization.
	aliasFactors := make([]string, len(entity.Aliases))

	for index, alias := range entity.Aliases {
		// Verify that alias is not associated to a different one already.
		// NOTE(review): MemDBAliasByFactors reads through its own read txn
		// (i.db.Txn(false)), so writes already made in txn are not visible
		// to this lookup.
		aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false)
		if err != nil {
			return err
		}

		if alias.NamespaceID == "" {
			alias.NamespaceID = namespace.RootNamespaceID
		}

		switch {
		case aliasByFactors == nil:
			// Not found, no merging needed, just check namespace
			if alias.NamespaceID != entity.NamespaceID {
				return errors.New("alias and entity are not in the same namespace")
			}

		case aliasByFactors.CanonicalID == entity.ID:
			// Lookup found the same entity, so it's already attached to the
			// right place
			if aliasByFactors.NamespaceID != entity.NamespaceID {
				return errors.New("alias from factors and entity are not in the same namespace")
			}

		case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID:
			// previousEntity isn't upserted yet so may still contain the old
			// alias reference in memdb if it was just changed; validate
			// whether or not it's _actually_ still tied to the entity
			var found bool
			for _, prevEntAlias := range previousEntity.Aliases {
				if prevEntAlias.ID == alias.ID {
					found = true
					break
				}
			}
			// If we didn't find the alias still tied to previousEntity, we
			// shouldn't use the merging logic and should bail.
			// This break leaves the switch only (not the for loop), so the
			// alias continues to the duplicate check and upsert below.
			if !found {
				break
			}

			// Otherwise it's still tied to previousEntity and fall through
			// into merging. We don't need a namespace check here as existing
			// checks when creating the aliases should ensure that all line up.
			fallthrough

		default:
			i.logger.Warn("alias is already tied to a different entity; these entities are being merged", "alias_id", alias.ID, "other_entity_id", aliasByFactors.CanonicalID, "entity_aliases", entity.Aliases, "alias_by_factors", aliasByFactors)

			respErr, intErr := i.mergeEntity(ctx, txn, entity, []string{aliasByFactors.CanonicalID}, true, false, true, persist)
			switch {
			case respErr != nil:
				return respErr
			case intErr != nil:
				return intErr
			}

			// The entity and aliases will be loaded into memdb and persisted
			// as a result of the merge so we are done here
			return nil
		}

		// Two aliases of this entity on the same mount whose names collide
		// after sanitization: fatal only when names are lower-cased.
		if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) {
			i.logger.Warn(errDuplicateIdentityName.Error(), "alias_name", alias.Name, "mount_accessor", alias.MountAccessor, "local", alias.Local, "entity_name", entity.Name, "action", "delete one of the duplicate aliases")
			if !i.disableLowerCasedNames {
				return errDuplicateIdentityName
			}
		}

		// Insert or update alias in MemDB using the caller-supplied transaction
		err = i.MemDBUpsertAliasInTxn(txn, alias, false)
		if err != nil {
			return err
		}

		aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor
	}

	// If previous entity is set, update it in MemDB and persist it
	if previousEntity != nil {
		err = i.MemDBUpsertEntityInTxn(txn, previousEntity)
		if err != nil {
			return err
		}

		if persist {
			// Persist the previous entity object
			if err := i.persistEntity(ctx, previousEntity); err != nil {
				return err
			}
		}
	}

	// Insert or update entity in MemDB using the caller-supplied transaction
	err = i.MemDBUpsertEntityInTxn(txn, entity)
	if err != nil {
		return err
	}

	if persist {
		if err := i.persistEntity(ctx, entity); err != nil {
			return err
		}
	}

	return nil
}
|
|
|
|
func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) {
|
|
if !lAlias.Local {
|
|
return nil, fmt.Errorf("alias is not local")
|
|
}
|
|
|
|
mountValidationResp := i.router.ValidateMountByAccessor(lAlias.MountAccessor)
|
|
if mountValidationResp == nil {
|
|
return nil, fmt.Errorf("invalid mount accessor %q", lAlias.MountAccessor)
|
|
}
|
|
|
|
if !mountValidationResp.MountLocal {
|
|
return nil, fmt.Errorf("mount accessor %q is not local", lAlias.MountAccessor)
|
|
}
|
|
|
|
alias, err := i.MemDBAliasByFactors(lAlias.MountAccessor, lAlias.Name, false, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if alias == nil {
|
|
alias = &identity.Alias{}
|
|
}
|
|
|
|
alias.CanonicalID = entity.ID
|
|
alias.Name = lAlias.Name
|
|
alias.MountAccessor = lAlias.MountAccessor
|
|
alias.Metadata = lAlias.Metadata
|
|
alias.MountPath = mountValidationResp.MountPath
|
|
alias.MountType = mountValidationResp.MountType
|
|
alias.Local = lAlias.Local
|
|
alias.CustomMetadata = lAlias.CustomMetadata
|
|
|
|
if err := i.sanitizeAlias(ctx, alias); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
entity.UpsertAlias(alias)
|
|
|
|
localAliases, err := i.parseLocalAliases(entity.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if localAliases == nil {
|
|
localAliases = &identity.LocalAliases{}
|
|
}
|
|
|
|
updated := false
|
|
for i, item := range localAliases.Aliases {
|
|
if item.ID == alias.ID {
|
|
localAliases.Aliases[i] = alias
|
|
updated = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !updated {
|
|
localAliases.Aliases = append(localAliases.Aliases, alias)
|
|
}
|
|
|
|
marshaledAliases, err := ptypes.MarshalAny(localAliases)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := i.localAliasPacker.PutItem(ctx, &storagepacker.Item{
|
|
ID: entity.ID,
|
|
Message: marshaledAliases,
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if updateDb {
|
|
txn := i.db.Txn(true)
|
|
defer txn.Abort()
|
|
if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := i.upsertEntityInTxn(ctx, txn, entity, nil, false); err != nil {
|
|
return nil, err
|
|
}
|
|
txn.Commit()
|
|
}
|
|
|
|
return alias, nil
|
|
}
|
|
|
|
// cacheTemporaryEntity stores in secondary's storage, the entity returned by
|
|
// the primary cluster via the CreateEntity RPC. This is so that the secondary
|
|
// cluster knows and retains information about the existence of these entities
|
|
// before the replication invalidation informs the secondary of the same. This
|
|
// also happens to cover the case where the secondary's replication is lagging
|
|
// behind the primary by hours and/or days which sometimes may happen. Even if
|
|
// the nodes of the secondary are restarted in the interim, the cluster would
|
|
// still be aware of the entities. This temporary cache will be cleared when the
|
|
// invalidation hits the secondary nodes.
|
|
func (i *IdentityStore) cacheTemporaryEntity(ctx context.Context, entity *identity.Entity) error {
|
|
if i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) && i.localNode.HAState() != consts.PerfStandby {
|
|
marshaledEntity, err := ptypes.MarshalAny(entity)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := i.localAliasPacker.PutItem(ctx, &storagepacker.Item{
|
|
ID: entity.ID + tmpSuffix,
|
|
Message: marshaledEntity,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error {
|
|
// If the entity that is passed into this function is resulting from a memdb
|
|
// query without cloning, then modifying it will result in a direct DB edit,
|
|
// bypassing the transaction. To avoid any surprises arising from this
|
|
// effect, work on a replica of the entity struct.
|
|
var err error
|
|
entity, err = entity.Clone()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Separate the local and non-local aliases.
|
|
var localAliases []*identity.Alias
|
|
var nonLocalAliases []*identity.Alias
|
|
for _, alias := range entity.Aliases {
|
|
switch alias.Local {
|
|
case true:
|
|
localAliases = append(localAliases, alias)
|
|
default:
|
|
nonLocalAliases = append(nonLocalAliases, alias)
|
|
}
|
|
}
|
|
|
|
// Store the entity with non-local aliases.
|
|
entity.Aliases = nonLocalAliases
|
|
marshaledEntity, err := ptypes.MarshalAny(entity)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := i.entityPacker.PutItem(ctx, &storagepacker.Item{
|
|
ID: entity.ID,
|
|
Message: marshaledEntity,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(localAliases) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Store the local aliases separately.
|
|
aliases := &identity.LocalAliases{
|
|
Aliases: localAliases,
|
|
}
|
|
|
|
marshaledAliases, err := ptypes.MarshalAny(aliases)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := i.localAliasPacker.PutItem(ctx, &storagepacker.Item{
|
|
ID: entity.ID,
|
|
Message: marshaledAliases,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// upsertEntity either creates or updates an existing entity. The operations
|
|
// will be updated in both MemDB and storage. If 'persist' is set to false,
|
|
// then storage will not be updated. When an alias is transferred from one
|
|
// entity to another, both the source and destination entities should get
|
|
// updated, in which case, callers should send in both entity and
|
|
// previousEntity.
|
|
func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
|
|
defer metrics.MeasureSince([]string{"identity", "upsert_entity"}, time.Now())
|
|
|
|
// Create a MemDB transaction to update both alias and entity
|
|
txn := i.db.Txn(true)
|
|
defer txn.Abort()
|
|
|
|
err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn.Commit()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBUpsertAliasInTxn(txn *memdb.Txn, alias *identity.Alias, groupAlias bool) error {
|
|
if txn == nil {
|
|
return fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if alias == nil {
|
|
return fmt.Errorf("alias is nil")
|
|
}
|
|
|
|
if alias.NamespaceID == "" {
|
|
alias.NamespaceID = namespace.RootNamespaceID
|
|
}
|
|
|
|
tableName := entityAliasesTable
|
|
if groupAlias {
|
|
tableName = groupAliasesTable
|
|
}
|
|
|
|
aliasRaw, err := txn.First(tableName, "id", alias.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to lookup alias from memdb using alias ID: %w", err)
|
|
}
|
|
|
|
if aliasRaw != nil {
|
|
err = txn.Delete(tableName, aliasRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete alias from memdb: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := txn.Insert(tableName, alias); err != nil {
|
|
return fmt.Errorf("failed to update alias into memdb: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBAliasByIDInTxn(txn *memdb.Txn, aliasID string, clone bool, groupAlias bool) (*identity.Alias, error) {
|
|
if aliasID == "" {
|
|
return nil, fmt.Errorf("missing alias ID")
|
|
}
|
|
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
tableName := entityAliasesTable
|
|
if groupAlias {
|
|
tableName = groupAliasesTable
|
|
}
|
|
|
|
aliasRaw, err := txn.First(tableName, "id", aliasID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch alias from memdb using alias ID: %w", err)
|
|
}
|
|
|
|
if aliasRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
alias, ok := aliasRaw.(*identity.Alias)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched alias")
|
|
}
|
|
|
|
if clone {
|
|
return alias.Clone()
|
|
}
|
|
|
|
return alias, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBAliasByID(aliasID string, clone bool, groupAlias bool) (*identity.Alias, error) {
|
|
if aliasID == "" {
|
|
return nil, fmt.Errorf("missing alias ID")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBAliasByIDInTxn(txn, aliasID, clone, groupAlias)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
|
|
if aliasName == "" {
|
|
return nil, fmt.Errorf("missing alias name")
|
|
}
|
|
|
|
if mountAccessor == "" {
|
|
return nil, fmt.Errorf("missing mount accessor")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if aliasName == "" {
|
|
return nil, fmt.Errorf("missing alias name")
|
|
}
|
|
|
|
if mountAccessor == "" {
|
|
return nil, fmt.Errorf("missing mount accessor")
|
|
}
|
|
|
|
tableName := entityAliasesTable
|
|
if groupAlias {
|
|
tableName = groupAliasesTable
|
|
}
|
|
|
|
aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %w", err)
|
|
}
|
|
|
|
if aliasRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
alias, ok := aliasRaw.(*identity.Alias)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched alias")
|
|
}
|
|
|
|
if clone {
|
|
return alias.Clone()
|
|
}
|
|
|
|
return alias, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBDeleteAliasByIDInTxn(txn *memdb.Txn, aliasID string, groupAlias bool) error {
|
|
if aliasID == "" {
|
|
return nil
|
|
}
|
|
|
|
if txn == nil {
|
|
return fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, groupAlias)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if alias == nil {
|
|
return nil
|
|
}
|
|
|
|
tableName := entityAliasesTable
|
|
if groupAlias {
|
|
tableName = groupAliasesTable
|
|
}
|
|
|
|
err = txn.Delete(tableName, alias)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete alias from memdb: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBAliases(ws memdb.WatchSet, groupAlias bool) (memdb.ResultIterator, error) {
|
|
txn := i.db.Txn(false)
|
|
|
|
tableName := entityAliasesTable
|
|
if groupAlias {
|
|
tableName = groupAliasesTable
|
|
}
|
|
|
|
iter, err := txn.Get(tableName, "id")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ws.Add(iter.WatchCh())
|
|
|
|
return iter, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBUpsertEntityInTxn(txn *memdb.Txn, entity *identity.Entity) error {
|
|
if txn == nil {
|
|
return fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if entity == nil {
|
|
return fmt.Errorf("entity is nil")
|
|
}
|
|
|
|
if entity.NamespaceID == "" {
|
|
entity.NamespaceID = namespace.RootNamespaceID
|
|
}
|
|
|
|
entityRaw, err := txn.First(entitiesTable, "id", entity.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to lookup entity from memdb using entity id: %w", err)
|
|
}
|
|
|
|
if entityRaw != nil {
|
|
err = txn.Delete(entitiesTable, entityRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete entity from memdb: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := txn.Insert(entitiesTable, entity); err != nil {
|
|
return fmt.Errorf("failed to update entity into memdb: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByIDInTxn(txn *memdb.Txn, entityID string, clone bool) (*identity.Entity, error) {
|
|
if entityID == "" {
|
|
return nil, fmt.Errorf("missing entity id")
|
|
}
|
|
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
entityRaw, err := txn.First(entitiesTable, "id", entityID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch entity from memdb using entity id: %w", err)
|
|
}
|
|
|
|
if entityRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
entity, ok := entityRaw.(*identity.Entity)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched entity")
|
|
}
|
|
|
|
if clone {
|
|
return entity.Clone()
|
|
}
|
|
|
|
return entity, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByID(entityID string, clone bool) (*identity.Entity, error) {
|
|
if entityID == "" {
|
|
return nil, fmt.Errorf("missing entity id")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBEntityByIDInTxn(txn, entityID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByName(ctx context.Context, entityName string, clone bool) (*identity.Entity, error) {
|
|
if entityName == "" {
|
|
return nil, fmt.Errorf("missing entity name")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBEntityByNameInTxn(ctx, txn, entityName, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByNameInTxn(ctx context.Context, txn *memdb.Txn, entityName string, clone bool) (*identity.Entity, error) {
|
|
if entityName == "" {
|
|
return nil, fmt.Errorf("missing entity name")
|
|
}
|
|
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
entityRaw, err := txn.First(entitiesTable, "name", ns.ID, entityName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch entity from memdb using entity name: %w", err)
|
|
}
|
|
|
|
if entityRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
entity, ok := entityRaw.(*identity.Entity)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched entity")
|
|
}
|
|
|
|
if clone {
|
|
return entity.Clone()
|
|
}
|
|
|
|
return entity, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBLocalAliasesByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Alias, error) {
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if bucketKey == "" {
|
|
return nil, fmt.Errorf("empty bucket key")
|
|
}
|
|
|
|
iter, err := txn.Get(entityAliasesTable, "local_bucket_key", bucketKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lookup aliases using local bucket entry key hash: %w", err)
|
|
}
|
|
|
|
var aliases []*identity.Alias
|
|
for item := iter.Next(); item != nil; item = iter.Next() {
|
|
alias := item.(*identity.Alias)
|
|
if alias.Local {
|
|
aliases = append(aliases, alias)
|
|
}
|
|
}
|
|
|
|
return aliases, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntitiesByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Entity, error) {
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if bucketKey == "" {
|
|
return nil, fmt.Errorf("empty bucket key")
|
|
}
|
|
|
|
entitiesIter, err := txn.Get(entitiesTable, "bucket_key", bucketKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lookup entities using bucket entry key hash: %w", err)
|
|
}
|
|
|
|
var entities []*identity.Entity
|
|
for item := entitiesIter.Next(); item != nil; item = entitiesIter.Next() {
|
|
entity, err := item.(*identity.Entity).Clone()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
entities = append(entities, entity)
|
|
}
|
|
|
|
return entities, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByMergedEntityID(mergedEntityID string, clone bool) (*identity.Entity, error) {
|
|
if mergedEntityID == "" {
|
|
return nil, fmt.Errorf("missing merged entity id")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
entityRaw, err := txn.First(entitiesTable, "merged_entity_ids", mergedEntityID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch entity from memdb using merged entity id: %w", err)
|
|
}
|
|
|
|
if entityRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
entity, ok := entityRaw.(*identity.Entity)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched entity")
|
|
}
|
|
|
|
if clone {
|
|
return entity.Clone()
|
|
}
|
|
|
|
return entity, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Entity, error) {
|
|
if aliasID == "" {
|
|
return nil, fmt.Errorf("missing alias ID")
|
|
}
|
|
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if alias == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBEntityByAliasID(aliasID string, clone bool) (*identity.Entity, error) {
|
|
if aliasID == "" {
|
|
return nil, fmt.Errorf("missing alias ID")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBEntityByAliasIDInTxn(txn, aliasID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBDeleteEntityByID(entityID string) error {
|
|
if entityID == "" {
|
|
return nil
|
|
}
|
|
|
|
txn := i.db.Txn(true)
|
|
defer txn.Abort()
|
|
|
|
err := i.MemDBDeleteEntityByIDInTxn(txn, entityID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn.Commit()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBDeleteEntityByIDInTxn(txn *memdb.Txn, entityID string) error {
|
|
if entityID == "" {
|
|
return nil
|
|
}
|
|
|
|
if txn == nil {
|
|
return fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
entity, err := i.MemDBEntityByIDInTxn(txn, entityID, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if entity == nil {
|
|
return nil
|
|
}
|
|
|
|
err = txn.Delete(entitiesTable, entity)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete entity from memdb: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) sanitizeAlias(ctx context.Context, alias *identity.Alias) error {
|
|
var err error
|
|
|
|
if alias == nil {
|
|
return fmt.Errorf("alias is nil")
|
|
}
|
|
|
|
// Alias must always be tied to a canonical object
|
|
if alias.CanonicalID == "" {
|
|
return fmt.Errorf("missing canonical ID")
|
|
}
|
|
|
|
// Alias must have a name
|
|
if alias.Name == "" {
|
|
return fmt.Errorf("missing alias name %q", alias.Name)
|
|
}
|
|
|
|
// Alias metadata should always be map[string]string
|
|
err = validateMetadata(alias.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid alias metadata: %w", err)
|
|
}
|
|
|
|
// Create an ID if there isn't one already
|
|
if alias.ID == "" {
|
|
alias.ID, err = uuid.GenerateUUID()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate alias ID")
|
|
}
|
|
|
|
alias.LocalBucketKey = i.localAliasPacker.BucketKey(alias.CanonicalID)
|
|
}
|
|
|
|
if alias.NamespaceID == "" {
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
alias.NamespaceID = ns.ID
|
|
}
|
|
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ns.ID != alias.NamespaceID {
|
|
return errors.New("alias belongs to a different namespace")
|
|
}
|
|
|
|
// Set the creation and last update times
|
|
if alias.CreationTime == nil {
|
|
alias.CreationTime = ptypes.TimestampNow()
|
|
alias.LastUpdateTime = alias.CreationTime
|
|
} else {
|
|
alias.LastUpdateTime = ptypes.TimestampNow()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) sanitizeEntity(ctx context.Context, entity *identity.Entity) error {
|
|
var err error
|
|
|
|
if entity == nil {
|
|
return fmt.Errorf("entity is nil")
|
|
}
|
|
|
|
// Create an ID if there isn't one already
|
|
if entity.ID == "" {
|
|
entity.ID, err = uuid.GenerateUUID()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate entity id")
|
|
}
|
|
|
|
// Set the storage bucket key in entity
|
|
entity.BucketKey = i.entityPacker.BucketKey(entity.ID)
|
|
}
|
|
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if entity.NamespaceID == "" {
|
|
entity.NamespaceID = ns.ID
|
|
}
|
|
if ns.ID != entity.NamespaceID {
|
|
return errors.New("entity does not belong to this namespace")
|
|
}
|
|
|
|
// Create a name if there isn't one already
|
|
if entity.Name == "" {
|
|
entity.Name, err = i.generateName(ctx, "entity")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate entity name")
|
|
}
|
|
}
|
|
|
|
// Entity metadata should always be map[string]string
|
|
err = validateMetadata(entity.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid entity metadata: %w", err)
|
|
}
|
|
|
|
// Set the creation and last update times
|
|
if entity.CreationTime == nil {
|
|
entity.CreationTime = ptypes.TimestampNow()
|
|
entity.LastUpdateTime = entity.CreationTime
|
|
} else {
|
|
entity.LastUpdateTime = ptypes.TimestampNow()
|
|
}
|
|
|
|
// Ensure that MFASecrets is non-nil at any time. This is useful when MFA
|
|
// secret generation procedures try to append MFA info to entity.
|
|
if entity.MFASecrets == nil {
|
|
entity.MFASecrets = make(map[string]*mfa.Secret)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) sanitizeAndUpsertGroup(ctx context.Context, group *identity.Group, previousGroup *identity.Group, memberGroupIDs []string) error {
|
|
var err error
|
|
|
|
if group == nil {
|
|
return fmt.Errorf("group is nil")
|
|
}
|
|
|
|
// Create an ID if there isn't one already
|
|
if group.ID == "" {
|
|
group.ID, err = uuid.GenerateUUID()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate group id")
|
|
}
|
|
|
|
// Set the hash value of the storage bucket key in group
|
|
group.BucketKey = i.groupPacker.BucketKey(group.ID)
|
|
}
|
|
|
|
if group.NamespaceID == "" {
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
group.NamespaceID = ns.ID
|
|
}
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ns.ID != group.NamespaceID {
|
|
return errors.New("group does not belong to this namespace")
|
|
}
|
|
|
|
// Create a name if there isn't one already
|
|
if group.Name == "" {
|
|
group.Name, err = i.generateName(ctx, "group")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate group name")
|
|
}
|
|
}
|
|
|
|
// Entity metadata should always be map[string]string
|
|
err = validateMetadata(group.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid group metadata: %w", err)
|
|
}
|
|
|
|
// Set the creation and last update times
|
|
if group.CreationTime == nil {
|
|
group.CreationTime = ptypes.TimestampNow()
|
|
group.LastUpdateTime = group.CreationTime
|
|
} else {
|
|
group.LastUpdateTime = ptypes.TimestampNow()
|
|
}
|
|
|
|
// Remove duplicate entity IDs and check if all IDs are valid
|
|
group.MemberEntityIDs = strutil.RemoveDuplicates(group.MemberEntityIDs, false)
|
|
for _, entityID := range group.MemberEntityIDs {
|
|
entity, err := i.MemDBEntityByID(entityID, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to validate entity ID %q: %w", entityID, err)
|
|
}
|
|
if entity == nil {
|
|
return fmt.Errorf("invalid entity ID %q", entityID)
|
|
}
|
|
}
|
|
|
|
txn := i.db.Txn(true)
|
|
defer txn.Abort()
|
|
|
|
var currentMemberGroupIDs []string
|
|
var currentMemberGroups []*identity.Group
|
|
|
|
// If there are no member group IDs supplied, then it shouldn't be
|
|
// processed. If an empty set of member group IDs are supplied, then it
|
|
// should be processed. Hence the nil check instead of the length check.
|
|
if memberGroupIDs == nil {
|
|
goto ALIAS
|
|
}
|
|
|
|
memberGroupIDs = strutil.RemoveDuplicates(memberGroupIDs, false)
|
|
|
|
// For those group member IDs that are removed from the list, remove current
|
|
// group ID as their respective ParentGroupID.
|
|
|
|
// Get the current MemberGroups IDs for this group
|
|
currentMemberGroups, err = i.MemDBGroupsByParentGroupID(group.ID, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, currentMemberGroup := range currentMemberGroups {
|
|
currentMemberGroupIDs = append(currentMemberGroupIDs, currentMemberGroup.ID)
|
|
}
|
|
|
|
// Update parent group IDs in the removed members
|
|
for _, currentMemberGroupID := range currentMemberGroupIDs {
|
|
if strutil.StrListContains(memberGroupIDs, currentMemberGroupID) {
|
|
continue
|
|
}
|
|
|
|
currentMemberGroup, err := i.MemDBGroupByID(currentMemberGroupID, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if currentMemberGroup == nil {
|
|
return fmt.Errorf("invalid member group ID %q", currentMemberGroupID)
|
|
}
|
|
|
|
// Remove group ID from the parent group IDs
|
|
currentMemberGroup.ParentGroupIDs = strutil.StrListDelete(currentMemberGroup.ParentGroupIDs, group.ID)
|
|
|
|
err = i.UpsertGroupInTxn(ctx, txn, currentMemberGroup, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// After the group lock is held, make membership updates to all the
|
|
// relevant groups
|
|
for _, memberGroupID := range memberGroupIDs {
|
|
memberGroup, err := i.MemDBGroupByID(memberGroupID, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if memberGroup == nil {
|
|
return fmt.Errorf("invalid member group ID %q", memberGroupID)
|
|
}
|
|
|
|
// Skip if memberGroupID is already a member of group.ID
|
|
if strutil.StrListContains(memberGroup.ParentGroupIDs, group.ID) {
|
|
continue
|
|
}
|
|
|
|
// Ensure that adding memberGroupID does not lead to cyclic
|
|
// relationships
|
|
// Detect self loop
|
|
if group.ID == memberGroupID {
|
|
return fmt.Errorf("member group ID %q is same as the ID of the group", group.ID)
|
|
}
|
|
|
|
groupByID, err := i.MemDBGroupByID(group.ID, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If group is nil, that means that a group doesn't already exist and its
|
|
// okay to add any group as its member group.
|
|
if groupByID != nil {
|
|
// If adding the memberGroupID to groupID creates a cycle, then groupID must
|
|
// be a hop in that loop. Start a DFS traversal from memberGroupID and see if
|
|
// it reaches back to groupID. If it does, then it's a loop.
|
|
|
|
// Created a visited set
|
|
visited := make(map[string]bool)
|
|
cycleDetected, err := i.detectCycleDFS(visited, groupByID.ID, memberGroupID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to perform cyclic relationship detection for member group ID %q", memberGroupID)
|
|
}
|
|
if cycleDetected {
|
|
return fmt.Errorf("cyclic relationship detected for member group ID %q", memberGroupID)
|
|
}
|
|
}
|
|
|
|
memberGroup.ParentGroupIDs = append(memberGroup.ParentGroupIDs, group.ID)
|
|
|
|
// This technically is not upsert. It is only update, only the method
|
|
// name is upsert here.
|
|
err = i.UpsertGroupInTxn(ctx, txn, memberGroup, true)
|
|
if err != nil {
|
|
// Ideally we would want to revert the whole operation in case of
|
|
// errors while persisting in member groups. But there is no
|
|
// storage transaction support yet. When we do have it, this will need
|
|
// an update.
|
|
return err
|
|
}
|
|
}
|
|
|
|
ALIAS:
|
|
// Sanitize the group alias
|
|
if group.Alias != nil {
|
|
group.Alias.CanonicalID = group.ID
|
|
err = i.sanitizeAlias(ctx, group.Alias)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// If previousGroup is not nil, we are moving the alias from the previous
|
|
// group to the new one. As a result we need to upsert both in the context
|
|
// of this same transaction.
|
|
if previousGroup != nil {
|
|
err = i.UpsertGroupInTxn(ctx, txn, previousGroup, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = i.UpsertGroupInTxn(ctx, txn, group, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn.Commit()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) deleteAliasesInEntityInTxn(txn *memdb.Txn, entity *identity.Entity, aliases []*identity.Alias) error {
|
|
if entity == nil {
|
|
return fmt.Errorf("entity is nil")
|
|
}
|
|
|
|
if txn == nil {
|
|
return fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
var remainList []*identity.Alias
|
|
var removeList []*identity.Alias
|
|
for _, item := range entity.Aliases {
|
|
remove := false
|
|
for _, alias := range aliases {
|
|
if alias.ID == item.ID {
|
|
remove = true
|
|
}
|
|
}
|
|
if remove {
|
|
removeList = append(removeList, item)
|
|
} else {
|
|
remainList = append(remainList, item)
|
|
}
|
|
}
|
|
|
|
// Remove identity indices from aliases table for those that needs to
|
|
// be removed
|
|
for _, alias := range removeList {
|
|
err := i.MemDBDeleteAliasByIDInTxn(txn, alias.ID, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Update the entity with remaining items
|
|
entity.Aliases = remainList
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateMeta validates a set of key/value pairs from the agent config
|
|
func validateMetadata(meta map[string]string) error {
|
|
if len(meta) > metaMaxKeyPairs {
|
|
return fmt.Errorf("metadata cannot contain more than %d key/value pairs", metaMaxKeyPairs)
|
|
}
|
|
|
|
for key, value := range meta {
|
|
if err := validateMetaPair(key, value); err != nil {
|
|
return fmt.Errorf("failed to load metadata pair (%q, %q): %w", key, value, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateMetaPair checks that the given key/value pair is in a valid format
|
|
func validateMetaPair(key, value string) error {
|
|
if key == "" {
|
|
return fmt.Errorf("key cannot be blank")
|
|
}
|
|
if !metaKeyFormatRegEx(key) {
|
|
return fmt.Errorf("key contains invalid characters")
|
|
}
|
|
if len(key) > metaKeyMaxLength {
|
|
return fmt.Errorf("key is too long (limit: %d characters)", metaKeyMaxLength)
|
|
}
|
|
if strings.HasPrefix(key, metaKeyReservedPrefix) {
|
|
return fmt.Errorf("key prefix %q is reserved for internal use", metaKeyReservedPrefix)
|
|
}
|
|
if len(value) > metaValueMaxLength {
|
|
return fmt.Errorf("value is too long (limit: %d characters)", metaValueMaxLength)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupByNameInTxn(ctx context.Context, txn *memdb.Txn, groupName string, clone bool) (*identity.Group, error) {
|
|
if groupName == "" {
|
|
return nil, fmt.Errorf("missing group name")
|
|
}
|
|
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
groupRaw, err := txn.First(groupsTable, "name", ns.ID, groupName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch group from memdb using group name: %w", err)
|
|
}
|
|
|
|
if groupRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
group, ok := groupRaw.(*identity.Group)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched group")
|
|
}
|
|
|
|
if clone {
|
|
return group.Clone()
|
|
}
|
|
|
|
return group, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupByName(ctx context.Context, groupName string, clone bool) (*identity.Group, error) {
|
|
if groupName == "" {
|
|
return nil, fmt.Errorf("missing group name")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBGroupByNameInTxn(ctx, txn, groupName, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) UpsertGroup(ctx context.Context, group *identity.Group, persist bool) error {
|
|
defer metrics.MeasureSince([]string{"identity", "upsert_group"}, time.Now())
|
|
|
|
txn := i.db.Txn(true)
|
|
defer txn.Abort()
|
|
|
|
err := i.UpsertGroupInTxn(ctx, txn, group, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn.Commit()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, group *identity.Group, persist bool) error {
|
|
defer metrics.MeasureSince([]string{"identity", "upsert_group_txn"}, time.Now())
|
|
|
|
var err error
|
|
|
|
if txn == nil {
|
|
return fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
if group == nil {
|
|
return fmt.Errorf("group is nil")
|
|
}
|
|
|
|
// Increment the modify index of the group
|
|
group.ModifyIndex++
|
|
|
|
// Clear the old alias from memdb
|
|
groupClone, err := i.MemDBGroupByID(group.ID, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if groupClone != nil && groupClone.Alias != nil {
|
|
err = i.MemDBDeleteAliasByIDInTxn(txn, groupClone.Alias.ID, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Add the new alias to memdb
|
|
if group.Alias != nil {
|
|
err = i.MemDBUpsertAliasInTxn(txn, group.Alias, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Insert or update group in MemDB using the transaction created above
|
|
err = i.MemDBUpsertGroupInTxn(txn, group)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if persist {
|
|
groupAsAny, err := ptypes.MarshalAny(group)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
item := &storagepacker.Item{
|
|
ID: group.ID,
|
|
Message: groupAsAny,
|
|
}
|
|
|
|
sent, err := i.groupUpdater.SendGroupUpdate(ctx, group)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !sent {
|
|
if err := i.groupPacker.PutItem(ctx, item); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBUpsertGroupInTxn(txn *memdb.Txn, group *identity.Group) error {
|
|
if txn == nil {
|
|
return fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if group == nil {
|
|
return fmt.Errorf("group is nil")
|
|
}
|
|
|
|
if group.NamespaceID == "" {
|
|
group.NamespaceID = namespace.RootNamespaceID
|
|
}
|
|
|
|
groupRaw, err := txn.First(groupsTable, "id", group.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to lookup group from memdb using group id: %w", err)
|
|
}
|
|
|
|
if groupRaw != nil {
|
|
err = txn.Delete(groupsTable, groupRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete group from memdb: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := txn.Insert(groupsTable, group); err != nil {
|
|
return fmt.Errorf("failed to update group into memdb: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBDeleteGroupByIDInTxn(txn *memdb.Txn, groupID string) error {
|
|
if groupID == "" {
|
|
return nil
|
|
}
|
|
|
|
if txn == nil {
|
|
return fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
group, err := i.MemDBGroupByIDInTxn(txn, groupID, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if group == nil {
|
|
return nil
|
|
}
|
|
|
|
err = txn.Delete("groups", group)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete group from memdb: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupByIDInTxn(txn *memdb.Txn, groupID string, clone bool) (*identity.Group, error) {
|
|
if groupID == "" {
|
|
return nil, fmt.Errorf("missing group ID")
|
|
}
|
|
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
groupRaw, err := txn.First(groupsTable, "id", groupID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch group from memdb using group ID: %w", err)
|
|
}
|
|
|
|
if groupRaw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
group, ok := groupRaw.(*identity.Group)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to declare the type of fetched group")
|
|
}
|
|
|
|
if clone {
|
|
return group.Clone()
|
|
}
|
|
|
|
return group, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupByID(groupID string, clone bool) (*identity.Group, error) {
|
|
if groupID == "" {
|
|
return nil, fmt.Errorf("missing group ID")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBGroupByIDInTxn(txn, groupID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupsByParentGroupIDInTxn(txn *memdb.Txn, memberGroupID string, clone bool) ([]*identity.Group, error) {
|
|
if memberGroupID == "" {
|
|
return nil, fmt.Errorf("missing member group ID")
|
|
}
|
|
|
|
groupsIter, err := txn.Get(groupsTable, "parent_group_ids", memberGroupID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lookup groups using member group ID: %w", err)
|
|
}
|
|
|
|
var groups []*identity.Group
|
|
for group := groupsIter.Next(); group != nil; group = groupsIter.Next() {
|
|
entry := group.(*identity.Group)
|
|
if clone {
|
|
entry, err = entry.Clone()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
groups = append(groups, entry)
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupsByParentGroupID(memberGroupID string, clone bool) ([]*identity.Group, error) {
|
|
if memberGroupID == "" {
|
|
return nil, fmt.Errorf("missing member group ID")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBGroupsByParentGroupIDInTxn(txn, memberGroupID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupsByMemberEntityID(entityID string, clone bool, externalOnly bool) ([]*identity.Group, error) {
|
|
txn := i.db.Txn(false)
|
|
defer txn.Abort()
|
|
|
|
return i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, clone, externalOnly)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupsByMemberEntityIDInTxn(txn *memdb.Txn, entityID string, clone bool, externalOnly bool) ([]*identity.Group, error) {
|
|
if entityID == "" {
|
|
return nil, fmt.Errorf("missing entity ID")
|
|
}
|
|
|
|
groupsIter, err := txn.Get(groupsTable, "member_entity_ids", entityID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lookup groups using entity ID: %w", err)
|
|
}
|
|
|
|
var groups []*identity.Group
|
|
for group := groupsIter.Next(); group != nil; group = groupsIter.Next() {
|
|
entry := group.(*identity.Group)
|
|
if externalOnly && entry.Type == groupTypeInternal {
|
|
continue
|
|
}
|
|
if clone {
|
|
entry, err = entry.Clone()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
groups = append(groups, entry)
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func (i *IdentityStore) groupPoliciesByEntityID(entityID string) (map[string][]string, error) {
|
|
if entityID == "" {
|
|
return nil, fmt.Errorf("empty entity ID")
|
|
}
|
|
|
|
groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
visited := make(map[string]bool)
|
|
policies := make(map[string][]string)
|
|
for _, group := range groups {
|
|
err := i.collectPoliciesReverseDFS(group, visited, policies)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return policies, nil
|
|
}
|
|
|
|
func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) {
|
|
if entityID == "" {
|
|
return nil, nil, fmt.Errorf("empty entity ID")
|
|
}
|
|
|
|
groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
visited := make(map[string]bool)
|
|
var tGroups []*identity.Group
|
|
for _, group := range groups {
|
|
gGroups, err := i.collectGroupsReverseDFS(group, visited, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
tGroups = append(tGroups, gGroups...)
|
|
}
|
|
|
|
// Remove duplicates
|
|
groupMap := make(map[string]*identity.Group)
|
|
for _, group := range tGroups {
|
|
groupMap[group.ID] = group
|
|
}
|
|
|
|
tGroups = make([]*identity.Group, 0, len(groupMap))
|
|
for _, group := range groupMap {
|
|
tGroups = append(tGroups, group)
|
|
}
|
|
|
|
diff := diffGroups(groups, tGroups)
|
|
|
|
// For sanity
|
|
// There should not be any group that gets deleted
|
|
if len(diff.Deleted) != 0 {
|
|
return nil, nil, fmt.Errorf("failed to diff group memberships")
|
|
}
|
|
|
|
return diff.Unmodified, diff.New, nil
|
|
}
|
|
|
|
func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) {
|
|
if group == nil {
|
|
return nil, fmt.Errorf("nil group")
|
|
}
|
|
|
|
// If traversal for a groupID is performed before, skip it
|
|
if visited[group.ID] {
|
|
return groups, nil
|
|
}
|
|
visited[group.ID] = true
|
|
|
|
groups = append(groups, group)
|
|
|
|
// Traverse all the parent groups
|
|
for _, parentGroupID := range group.ParentGroupIDs {
|
|
parentGroup, err := i.MemDBGroupByID(parentGroupID, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if parentGroup == nil {
|
|
continue
|
|
}
|
|
groups, err = i.collectGroupsReverseDFS(parentGroup, visited, groups)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to collect group at parent group ID %q", parentGroup.ID)
|
|
}
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func (i *IdentityStore) collectPoliciesReverseDFS(group *identity.Group, visited map[string]bool, policies map[string][]string) error {
|
|
if group == nil {
|
|
return fmt.Errorf("nil group")
|
|
}
|
|
|
|
// If traversal for a groupID is performed before, skip it
|
|
if visited[group.ID] {
|
|
return nil
|
|
}
|
|
visited[group.ID] = true
|
|
|
|
policies[group.NamespaceID] = append(policies[group.NamespaceID], group.Policies...)
|
|
|
|
// Traverse all the parent groups
|
|
for _, parentGroupID := range group.ParentGroupIDs {
|
|
parentGroup, err := i.MemDBGroupByID(parentGroupID, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if parentGroup == nil {
|
|
continue
|
|
}
|
|
err = i.collectPoliciesReverseDFS(parentGroup, visited, policies)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to collect policies at parent group ID %q", parentGroup.ID)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *IdentityStore) detectCycleDFS(visited map[string]bool, startingGroupID, groupID string) (bool, error) {
|
|
// If the traversal reaches the startingGroupID, a loop is detected
|
|
if startingGroupID == groupID {
|
|
return true, nil
|
|
}
|
|
|
|
// If traversal for a groupID is performed before, skip it
|
|
if visited[groupID] {
|
|
return false, nil
|
|
}
|
|
visited[groupID] = true
|
|
|
|
group, err := i.MemDBGroupByID(groupID, true)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if group == nil {
|
|
return false, nil
|
|
}
|
|
|
|
// Fetch all groups in which groupID is present as a ParentGroupID. In
|
|
// other words, find all the subgroups of groupID.
|
|
memberGroups, err := i.MemDBGroupsByParentGroupID(groupID, false)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// DFS traverse the member groups
|
|
for _, memberGroup := range memberGroups {
|
|
cycleDetected, err := i.detectCycleDFS(visited, startingGroupID, memberGroup.ID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to perform cycle detection at member group ID %q", memberGroup.ID)
|
|
}
|
|
if cycleDetected {
|
|
return true, fmt.Errorf("cycle detected at member group ID %q", memberGroup.ID)
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func (i *IdentityStore) memberGroupIDsByID(groupID string) ([]string, error) {
|
|
var memberGroupIDs []string
|
|
memberGroups, err := i.MemDBGroupsByParentGroupID(groupID, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, memberGroup := range memberGroups {
|
|
memberGroupIDs = append(memberGroupIDs, memberGroup.ID)
|
|
}
|
|
return memberGroupIDs, nil
|
|
}
|
|
|
|
func (i *IdentityStore) generateName(ctx context.Context, entryType string) (string, error) {
|
|
var name string
|
|
OUTER:
|
|
for {
|
|
randBytes, err := uuid.GenerateRandomBytes(4)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
name = fmt.Sprintf("%s_%s", entryType, fmt.Sprintf("%08x", randBytes[0:4]))
|
|
|
|
switch entryType {
|
|
case "entity":
|
|
entity, err := i.MemDBEntityByName(ctx, name, false)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if entity == nil {
|
|
break OUTER
|
|
}
|
|
case "group":
|
|
group, err := i.MemDBGroupByName(ctx, name, false)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if group == nil {
|
|
break OUTER
|
|
}
|
|
default:
|
|
return "", fmt.Errorf("unrecognized type %q", entryType)
|
|
}
|
|
}
|
|
|
|
return name, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupsByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Group, error) {
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("nil txn")
|
|
}
|
|
|
|
if bucketKey == "" {
|
|
return nil, fmt.Errorf("empty bucket key")
|
|
}
|
|
|
|
groupsIter, err := txn.Get(groupsTable, "bucket_key", bucketKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lookup groups using bucket entry key hash: %w", err)
|
|
}
|
|
|
|
var groups []*identity.Group
|
|
for group := groupsIter.Next(); group != nil; group = groupsIter.Next() {
|
|
groups = append(groups, group.(*identity.Group))
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Group, error) {
|
|
if aliasID == "" {
|
|
return nil, fmt.Errorf("missing alias ID")
|
|
}
|
|
|
|
if txn == nil {
|
|
return nil, fmt.Errorf("txn is nil")
|
|
}
|
|
|
|
alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if alias == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
return i.MemDBGroupByIDInTxn(txn, alias.CanonicalID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) MemDBGroupByAliasID(aliasID string, clone bool) (*identity.Group, error) {
|
|
if aliasID == "" {
|
|
return nil, fmt.Errorf("missing alias ID")
|
|
}
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
return i.MemDBGroupByAliasIDInTxn(txn, aliasID, clone)
|
|
}
|
|
|
|
func (i *IdentityStore) refreshExternalGroupMembershipsByEntityID(ctx context.Context, entityID string, groupAliases []*logical.Alias, mountAccessor string) ([]*logical.Alias, error) {
|
|
defer metrics.MeasureSince([]string{"identity", "refresh_external_groups"}, time.Now())
|
|
|
|
if entityID == "" {
|
|
return nil, fmt.Errorf("empty entity ID")
|
|
}
|
|
|
|
refreshFunc := func(dryRun bool) (bool, []*logical.Alias, error) {
|
|
if !dryRun {
|
|
i.groupLock.Lock()
|
|
defer i.groupLock.Unlock()
|
|
}
|
|
|
|
txn := i.db.Txn(!dryRun)
|
|
defer txn.Abort()
|
|
|
|
oldGroups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, true, true)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
|
|
var newGroups []*identity.Group
|
|
var validAliases []*logical.Alias
|
|
for _, alias := range groupAliases {
|
|
aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, true, true)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
if aliasByFactors == nil {
|
|
continue
|
|
}
|
|
mappingGroup, err := i.MemDBGroupByAliasIDInTxn(txn, aliasByFactors.ID, true)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
if mappingGroup == nil {
|
|
return false, nil, fmt.Errorf("group unavailable for a valid alias ID %q", aliasByFactors.ID)
|
|
}
|
|
|
|
newGroups = append(newGroups, mappingGroup)
|
|
validAliases = append(validAliases, alias)
|
|
}
|
|
|
|
diff := diffGroups(oldGroups, newGroups)
|
|
|
|
// Add the entity ID to all the new groups
|
|
for _, group := range diff.New {
|
|
if group.Type != groupTypeExternal {
|
|
continue
|
|
}
|
|
|
|
// We need to update a group, if we are in a dry run we should
|
|
// report back that a change needs to take place.
|
|
if dryRun {
|
|
return true, nil, nil
|
|
}
|
|
|
|
i.logger.Debug("adding member entity ID to external group", "member_entity_id", entityID, "group_id", group.ID)
|
|
|
|
group.MemberEntityIDs = append(group.MemberEntityIDs, entityID)
|
|
|
|
err = i.UpsertGroupInTxn(ctx, txn, group, true)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
}
|
|
|
|
// Remove the entity ID from all the deleted groups
|
|
for _, group := range diff.Deleted {
|
|
if group.Type != groupTypeExternal {
|
|
continue
|
|
}
|
|
|
|
// If the external group is from a different mount, don't remove the
|
|
// entity ID from it.
|
|
if mountAccessor != "" && group.Alias != nil && group.Alias.MountAccessor != mountAccessor {
|
|
continue
|
|
}
|
|
|
|
// We need to update a group, if we are in a dry run we should
|
|
// report back that a change needs to take place.
|
|
if dryRun {
|
|
return true, nil, nil
|
|
}
|
|
|
|
i.logger.Debug("removing member entity ID from external group", "member_entity_id", entityID, "group_id", group.ID)
|
|
|
|
group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, entityID)
|
|
|
|
err = i.UpsertGroupInTxn(ctx, txn, group, true)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
}
|
|
|
|
txn.Commit()
|
|
return false, validAliases, nil
|
|
}
|
|
|
|
// dryRun
|
|
needsUpdate, validAliases, err := refreshFunc(true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if needsUpdate || len(groupAliases) > 0 {
|
|
i.logger.Debug("refreshing external group memberships", "entity_id", entityID, "group_aliases", groupAliases)
|
|
}
|
|
|
|
if !needsUpdate {
|
|
return validAliases, nil
|
|
}
|
|
|
|
// Run the update
|
|
_, validAliases, err = refreshFunc(false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return validAliases, nil
|
|
}
|
|
|
|
// diffGroups is used to diff two sets of groups
|
|
func diffGroups(old, new []*identity.Group) *groupDiff {
|
|
diff := &groupDiff{}
|
|
|
|
existing := make(map[string]*identity.Group)
|
|
for _, group := range old {
|
|
existing[group.ID] = group
|
|
}
|
|
|
|
for _, group := range new {
|
|
// Check if the entry in new is present in the old
|
|
_, ok := existing[group.ID]
|
|
|
|
// If its not present, then its a new entry
|
|
if !ok {
|
|
diff.New = append(diff.New, group)
|
|
continue
|
|
}
|
|
|
|
// If its present, it means that its unmodified
|
|
diff.Unmodified = append(diff.Unmodified, group)
|
|
|
|
// By deleting the unmodified from the old set, we could determine the
|
|
// ones that are stale by looking at the remaining ones.
|
|
delete(existing, group.ID)
|
|
}
|
|
|
|
// Any remaining entries must have been deleted
|
|
for _, me := range existing {
|
|
diff.Deleted = append(diff.Deleted, me)
|
|
}
|
|
|
|
return diff
|
|
}
|
|
|
|
func (i *IdentityStore) handleAliasListCommon(ctx context.Context, groupAlias bool) (*logical.Response, error) {
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tableName := entityAliasesTable
|
|
if groupAlias {
|
|
tableName = groupAliasesTable
|
|
}
|
|
|
|
ws := memdb.NewWatchSet()
|
|
|
|
txn := i.db.Txn(false)
|
|
|
|
iter, err := txn.Get(tableName, "namespace_id", ns.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch iterator for aliases in memdb: %w", err)
|
|
}
|
|
|
|
ws.Add(iter.WatchCh())
|
|
|
|
var aliasIDs []string
|
|
aliasInfo := map[string]interface{}{}
|
|
|
|
type mountInfo struct {
|
|
MountType string
|
|
MountPath string
|
|
}
|
|
mountAccessorMap := map[string]mountInfo{}
|
|
|
|
for {
|
|
raw := iter.Next()
|
|
if raw == nil {
|
|
break
|
|
}
|
|
alias := raw.(*identity.Alias)
|
|
aliasIDs = append(aliasIDs, alias.ID)
|
|
aliasInfoEntry := map[string]interface{}{
|
|
"name": alias.Name,
|
|
"canonical_id": alias.CanonicalID,
|
|
"mount_accessor": alias.MountAccessor,
|
|
"custom_metadata": alias.CustomMetadata,
|
|
"local": alias.Local,
|
|
}
|
|
|
|
mi, ok := mountAccessorMap[alias.MountAccessor]
|
|
if ok {
|
|
aliasInfoEntry["mount_type"] = mi.MountType
|
|
aliasInfoEntry["mount_path"] = mi.MountPath
|
|
} else {
|
|
mi = mountInfo{}
|
|
if mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
|
|
mi.MountType = mountValidationResp.MountType
|
|
mi.MountPath = mountValidationResp.MountPath
|
|
aliasInfoEntry["mount_type"] = mi.MountType
|
|
aliasInfoEntry["mount_path"] = mi.MountPath
|
|
}
|
|
mountAccessorMap[alias.MountAccessor] = mi
|
|
}
|
|
|
|
aliasInfo[alias.ID] = aliasInfoEntry
|
|
}
|
|
|
|
return logical.ListResponseWithInfo(aliasIDs, aliasInfo), nil
|
|
}
|
|
|
|
func (i *IdentityStore) countEntities() (int, error) {
|
|
txn := i.db.Txn(false)
|
|
|
|
iter, err := txn.Get(entitiesTable, "id")
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
count := 0
|
|
val := iter.Next()
|
|
for val != nil {
|
|
count++
|
|
val = iter.Next()
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// Sum up the number of entities belonging to each namespace (keyed by ID)
|
|
func (i *IdentityStore) countEntitiesByNamespace(ctx context.Context) (map[string]int, error) {
|
|
txn := i.db.Txn(false)
|
|
iter, err := txn.Get(entitiesTable, "id")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
byNamespace := make(map[string]int)
|
|
val := iter.Next()
|
|
for val != nil {
|
|
// Check if runtime exceeded.
|
|
select {
|
|
case <-ctx.Done():
|
|
return byNamespace, errors.New("context cancelled")
|
|
default:
|
|
break
|
|
}
|
|
|
|
// Count in the namespace attached to the entity.
|
|
entity := val.(*identity.Entity)
|
|
byNamespace[entity.NamespaceID] = byNamespace[entity.NamespaceID] + 1
|
|
val = iter.Next()
|
|
}
|
|
|
|
return byNamespace, nil
|
|
}
|
|
|
|
// Sum up the number of entities belonging to each mount point (keyed by accessor)
|
|
func (i *IdentityStore) countEntitiesByMountAccessor(ctx context.Context) (map[string]int, error) {
|
|
txn := i.db.Txn(false)
|
|
iter, err := txn.Get(entitiesTable, "id")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
byMountAccessor := make(map[string]int)
|
|
val := iter.Next()
|
|
for val != nil {
|
|
// Check if runtime exceeded.
|
|
select {
|
|
case <-ctx.Done():
|
|
return byMountAccessor, errors.New("context cancelled")
|
|
default:
|
|
break
|
|
}
|
|
|
|
// Count each alias separately; will translate to mount point and type
|
|
// in the caller.
|
|
entity := val.(*identity.Entity)
|
|
for _, alias := range entity.Aliases {
|
|
byMountAccessor[alias.MountAccessor] = byMountAccessor[alias.MountAccessor] + 1
|
|
}
|
|
val = iter.Next()
|
|
}
|
|
|
|
return byMountAccessor, nil
|
|
}
|