Refactor usages of Core in IdentityStore so they can be decoupled. (#12461)

This commit is contained in:
Nick Cabatoff 2021-08-30 21:31:11 +02:00 committed by GitHub
parent a8ee8854e3
commit 0762f9003d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 141 additions and 49 deletions

View File

@ -148,3 +148,12 @@ func (r ReplicationState) HasState(flag ReplicationState) bool { return r&flag !
func (r *ReplicationState) AddState(flag ReplicationState) { *r |= flag }
func (r *ReplicationState) ClearState(flag ReplicationState) { *r &= ^flag }
func (r *ReplicationState) ToggleState(flag ReplicationState) { *r ^= flag }
type HAState uint32
const (
_ HAState = iota
Standby
PerfStandby
Active
)

View File

@ -574,6 +574,17 @@ type Core struct {
enableResponseHeaderRaftNodeID bool
}
func (c *Core) HAState() consts.HAState {
switch {
case c.perfStandby:
return consts.PerfStandby
case c.standby:
return consts.Standby
default:
return consts.Active
}
}
// CoreConfig is used to parameterize a core
type CoreConfig struct {
entCoreConfig

View File

@ -277,7 +277,7 @@ func (d dynamicSystemView) EntityInfo(entityID string) (*logical.Entity, error)
alias := identity.ToSDKAlias(a)
// MountType is not stored with the entity and must be looked up
if mount := d.core.router.validateMountByAccessor(a.MountAccessor); mount != nil {
if mount := d.core.router.ValidateMountByAccessor(a.MountAccessor); mount != nil {
alias.MountType = mount.MountType
}

View File

@ -26,7 +26,6 @@ const (
var (
caseSensitivityKey = "casesensitivity"
sendGroupUpgrade = func(context.Context, *IdentityStore, *identity.Group) (bool, error) { return false, nil }
parseExtraEntityFromBucket = func(context.Context, *IdentityStore, *identity.Entity) (bool, error) { return false, nil }
addExtraEntityDataToResponse = func(*identity.Entity, map[string]interface{}) {}
)
@ -48,9 +47,15 @@ func (i *IdentityStore) resetDB(ctx context.Context) error {
func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendConfig, logger log.Logger) (*IdentityStore, error) {
iStore := &IdentityStore{
view: config.StorageView,
logger: logger,
core: core,
view: config.StorageView,
logger: logger,
router: core.router,
redirectAddr: core.redirectAddr,
localNode: core,
namespacer: core,
metrics: core.MetricSink(),
totpPersister: core,
groupUpdater: core,
}
// Create a memdb instance, which by default, operates on lower cased
@ -392,7 +397,7 @@ func (i *IdentityStore) parseEntityFromBucketItem(ctx context.Context, item *sto
persistNeeded = true
}
if persistNeeded && !i.core.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) {
if persistNeeded && !i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) {
entityAsAny, err := ptypes.MarshalAny(&entity)
if err != nil {
return nil, err
@ -495,7 +500,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
return nil, fmt.Errorf("empty alias name")
}
mountValidationResp := i.core.router.validateMountByAccessor(alias.MountAccessor)
mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor)
if mountValidationResp == nil {
return nil, fmt.Errorf("invalid mount accessor %q", alias.MountAccessor)
}
@ -571,14 +576,14 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
}
// Emit a metric for the new entity
ns, err := NamespaceByID(ctx, entity.NamespaceID, i.core)
ns, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID)
var nsLabel metrics.Label
if err != nil {
nsLabel = metrics.Label{"namespace", "unknown"}
} else {
nsLabel = metricsutil.NamespaceLabel(ns)
}
i.core.MetricSink().IncrCounterWithLabels(
i.metrics.IncrCounterWithLabels(
[]string{"identity", "entity", "creation"},
1,
[]metrics.Label{

View File

@ -177,7 +177,7 @@ func (i *IdentityStore) handleAliasCreateUpdate() framework.OperationFunc {
}
// Look up the alias by factors; if it's found it's an update
mountEntry := i.core.router.MatchingMountByAccessor(mountAccessor)
mountEntry := i.router.MatchingMountByAccessor(mountAccessor)
if mountEntry == nil {
return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil
}
@ -281,7 +281,7 @@ func (i *IdentityStore) handleAliasUpdate(ctx context.Context, req *logical.Requ
// namespace.
if name != alias.Name || mountAccessor != alias.MountAccessor {
// Check here to see if such an alias already exists, if so bail
mountEntry := i.core.router.MatchingMountByAccessor(mountAccessor)
mountEntry := i.router.MatchingMountByAccessor(mountAccessor)
if mountEntry == nil {
return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil
}
@ -413,7 +413,7 @@ func (i *IdentityStore) handleAliasReadCommon(ctx context.Context, alias *identi
respData["merged_from_canonical_ids"] = alias.MergedFromCanonicalIDs
respData["namespace_id"] = alias.NamespaceID
if mountValidationResp := i.core.router.validateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
if mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
respData["mount_path"] = mountValidationResp.MountPath
respData["mount_type"] = mountValidationResp.MountType
}

View File

@ -210,7 +210,7 @@ func TestIdentityStore_MemDBAliasIndexes(t *testing.T) {
t.Fatal("failed to create test identity store")
}
validateMountResp := is.core.router.validateMountByAccessor(githubAccessor)
validateMountResp := is.router.ValidateMountByAccessor(githubAccessor)
if validateMountResp == nil {
t.Fatal("failed to validate github auth mount")
}

View File

@ -374,7 +374,7 @@ func (i *IdentityStore) handleEntityReadCommon(ctx context.Context, entity *iden
aliasMap["creation_time"] = ptypes.TimestampString(alias.CreationTime)
aliasMap["last_update_time"] = ptypes.TimestampString(alias.LastUpdateTime)
if mountValidationResp := i.core.router.validateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
if mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
aliasMap["mount_type"] = mountValidationResp.MountType
aliasMap["mount_path"] = mountValidationResp.MountPath
}
@ -696,7 +696,7 @@ func (i *IdentityStore) handlePathEntityListCommon(ctx context.Context, req *log
entry["mount_path"] = mi.MountPath
} else {
mi = mountInfo{}
if mountValidationResp := i.core.router.validateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
if mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
mi.MountType = mountValidationResp.MountType
mi.MountPath = mountValidationResp.MountPath
entry["mount_type"] = mi.MountType
@ -765,7 +765,8 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
}
}
isPerfSecondaryOrStandby := i.core.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.core.perfStandby
isPerfSecondaryOrStandby := i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) ||
i.localNode.HAState() == consts.PerfStandby
for _, fromEntityID := range fromEntityIDs {
if fromEntityID == toEntity.ID {
return errors.New("to_entity_id should not be present in from_entity_ids"), nil

View File

@ -503,7 +503,7 @@ func TestIdentityStore_MemDBImmutability(t *testing.T) {
ctx := namespace.RootContext(nil)
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(ctx, t)
validateMountResp := is.core.router.validateMountByAccessor(githubAccessor)
validateMountResp := is.router.ValidateMountByAccessor(githubAccessor)
if validateMountResp == nil {
t.Fatal("failed to validate github auth mount")
}
@ -780,7 +780,7 @@ func TestIdentityStore_MemDBEntityIndexes(t *testing.T) {
ctx := namespace.RootContext(nil)
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(ctx, t)
validateMountResp := is.core.router.validateMountByAccessor(githubAccessor)
validateMountResp := is.router.ValidateMountByAccessor(githubAccessor)
if validateMountResp == nil {
t.Fatal("failed to validate github auth mount")
}

View File

@ -188,7 +188,7 @@ func (i *IdentityStore) handleGroupAliasUpdateCommon(ctx context.Context, req *l
// Validate name/accessor whether new or update
{
mountEntry := i.core.router.MatchingMountByAccessor(mountAccessor)
mountEntry := i.router.MatchingMountByAccessor(mountAccessor)
if mountEntry == nil {
return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil
}

View File

@ -346,7 +346,7 @@ func (i *IdentityStore) handleGroupReadCommon(ctx context.Context, group *identi
aliasMap["creation_time"] = ptypes.TimestampString(group.Alias.CreationTime)
aliasMap["last_update_time"] = ptypes.TimestampString(group.Alias.LastUpdateTime)
if mountValidationResp := i.core.router.validateMountByAccessor(group.Alias.MountAccessor); mountValidationResp != nil {
if mountValidationResp := i.router.ValidateMountByAccessor(group.Alias.MountAccessor); mountValidationResp != nil {
aliasMap["mount_path"] = mountValidationResp.MountPath
aliasMap["mount_type"] = mountValidationResp.MountType
}
@ -516,7 +516,7 @@ func (i *IdentityStore) handleGroupListCommon(ctx context.Context, byID bool) (*
entry["mount_path"] = mi.MountPath
} else {
mi = mountInfo{}
if mountValidationResp := i.core.router.validateMountByAccessor(group.Alias.MountAccessor); mountValidationResp != nil {
if mountValidationResp := i.router.ValidateMountByAccessor(group.Alias.MountAccessor); mountValidationResp != nil {
mi.MountType = mountValidationResp.MountType
mi.MountPath = mountValidationResp.MountPath
entry["mount_type"] = mi.MountType

View File

@ -313,7 +313,7 @@ func (i *IdentityStore) pathOIDCReadConfig(ctx context.Context, req *logical.Req
},
}
if i.core.redirectAddr == "" && c.Issuer == "" {
if i.redirectAddr == "" && c.Issuer == "" {
resp.AddWarning(`Both "issuer" and Vault's "api_addr" are empty. ` +
`The issuer claim in generated tokens will not be network reachable.`)
}
@ -416,7 +416,7 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
c.effectiveIssuer = c.Issuer
if c.effectiveIssuer == "" {
c.effectiveIssuer = i.core.redirectAddr
c.effectiveIssuer = i.redirectAddr
}
c.effectiveIssuer += "/v1/" + ns.Path + issuerPath
@ -1667,10 +1667,10 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
// based on key rotation times.
nextRun = now.Add(24 * time.Hour)
for _, ns := range i.listNamespaces() {
for _, ns := range i.namespacer.ListNamespaces() {
nsPath := ns.Path
s := i.core.router.MatchingStorageByAPIPath(ctx, nsPath+"identity/oidc")
s := i.router.MatchingStorageByAPIPath(ctx, nsPath+"identity/oidc")
if s == nil {
continue

View File

@ -293,7 +293,7 @@ func (i *IdentityStore) clientNamesReferencingTargetAssignmentName(ctx context.C
}
var names []string
for client, _ := range clients {
for client := range clients {
names = append(names, client)
}
sort.Strings(names)
@ -337,7 +337,7 @@ func (i *IdentityStore) clientNamesReferencingTargetKeyName(ctx context.Context,
}
var names []string
for client, _ := range clients {
for client := range clients {
names = append(names, client)
}
sort.Strings(names)
@ -959,7 +959,7 @@ func (i *IdentityStore) getOIDCProvider(ctx context.Context, s logical.Storage,
provider.effectiveIssuer = provider.Issuer
if provider.effectiveIssuer == "" {
provider.effectiveIssuer = i.core.redirectAddr
provider.effectiveIssuer = i.redirectAddr
}
provider.effectiveIssuer += "/v1/" + ns.Path + "identity/oidc/provider/" + name

View File

@ -0,0 +1,17 @@
// +build !enterprise
package vault
import (
"context"
"github.com/hashicorp/vault/helper/identity"
)
func (c *Core) PersistTOTPKey(context.Context, string, string, string) error {
return nil
}
func (c *Core) SendGroupUpdate(context.Context, *identity.Group) (bool, error) {
return false, nil
}

View File

@ -1,14 +1,18 @@
package vault
import (
"context"
"regexp"
"sync"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
)
@ -72,12 +76,17 @@ type IdentityStore struct {
// buckets
groupPacker *storagepacker.StoragePacker
// core is the pointer to Vault's core
core *Core
// disableLowerCaseNames indicates whether or not identity artifacts are
// operated case insensitively
disableLowerCasedNames bool
router *Router
redirectAddr string
localNode LocalNode
namespacer Namespacer
metrics metricsutil.Metrics
totpPersister TOTPPersister
groupUpdater GroupUpdater
}
type groupDiff struct {
@ -89,3 +98,29 @@ type groupDiff struct {
type casesensitivity struct {
DisableLowerCasedNames bool `json:"disable_lower_cased_names"`
}
type LocalNode interface {
ReplicationState() consts.ReplicationState
HAState() consts.HAState
}
var _ LocalNode = &Core{}
type Namespacer interface {
NamespaceByID(context.Context, string) (*namespace.Namespace, error)
ListNamespaces() []*namespace.Namespace
}
var _ Namespacer = &Core{}
type TOTPPersister interface {
PersistTOTPKey(ctx context.Context, configID string, entityID string, key string) error
}
var _ TOTPPersister = &Core{}
type GroupUpdater interface {
SendGroupUpdate(ctx context.Context, group *identity.Group) (bool, error)
}
var _ GroupUpdater = &Core{}

View File

@ -105,13 +105,13 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error {
continue
}
ns, err := NamespaceByID(ctx, group.NamespaceID, i.core)
ns, err := i.namespacer.NamespaceByID(ctx, group.NamespaceID)
if err != nil {
return err
}
if ns == nil {
// Remove dangling groups
if !(i.core.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.core.perfStandby) {
if !(i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby) {
// Group's namespace doesn't exist anymore but the group
// from the namespace still exists.
i.logger.Warn("deleting group and its any existing aliases", "name", group.Name, "namespace_id", group.NamespaceID)
@ -273,13 +273,13 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error {
continue
}
ns, err := NamespaceByID(ctx, entity.NamespaceID, i.core)
ns, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID)
if err != nil {
return err
}
if ns == nil {
// Remove dangling entities
if !(i.core.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.core.perfStandby) {
if !(i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.localNode.HAState() == consts.PerfStandby) {
// Entity's namespace doesn't exist anymore but the
// entity from the namespace still exists.
i.logger.Warn("deleting entity and its any existing aliases", "name", entity.Name, "namespace_id", entity.NamespaceID)
@ -1435,7 +1435,7 @@ func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, gr
Message: groupAsAny,
}
sent, err := sendGroupUpgrade(ctx, i, group)
sent, err := i.groupUpdater.SendGroupUpdate(ctx, group)
if err != nil {
return err
}
@ -2091,7 +2091,7 @@ func (i *IdentityStore) handleAliasListCommon(ctx context.Context, groupAlias bo
aliasInfoEntry["mount_path"] = mi.MountPath
} else {
mi = mountInfo{}
if mountValidationResp := i.core.router.validateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
if mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
mi.MountType = mountValidationResp.MountType
mi.MountPath = mountValidationResp.MountPath
aliasInfoEntry["mount_type"] = mi.MountType

View File

@ -60,10 +60,11 @@ const (
identityMountPath = "identity/"
cubbyholeMountPath = "cubbyhole/"
systemMountType = "system"
identityMountType = "identity"
cubbyholeMountType = "cubbyhole"
pluginMountType = "plugin"
systemMountType = "system"
identityMountType = "identity"
cubbyholeMountType = "cubbyhole"
pluginMountType = "plugin"
mountTypeNSCubbyhole = "ns_cubbyhole"
MountTableUpdateStorage = true
MountTableNoUpdateStorage = false

View File

@ -8,10 +8,6 @@ import (
var NamespaceByID func(context.Context, string, *Core) (*namespace.Namespace, error) = namespaceByID
const (
mountTypeNSCubbyhole = "ns_cubbyhole"
)
func namespaceByID(ctx context.Context, nsID string, c *Core) (*namespace.Namespace, error) {
if nsID == namespace.RootNamespaceID {
return namespace.RootNamespace, nil

17
vault/namespaces_oss.go Normal file
View File

@ -0,0 +1,17 @@
// +build !enterprise
package vault
import (
"context"
"github.com/hashicorp/vault/helper/namespace"
)
func (c *Core) NamespaceByID(ctx context.Context, nsID string) (*namespace.Namespace, error) {
return namespaceByID(ctx, nsID, c)
}
func (c *Core) ListNamespaces() []*namespace.Namespace {
return []*namespace.Namespace{namespace.RootNamespace}
}

View File

@ -59,7 +59,7 @@ type routeEntry struct {
l sync.RWMutex
}
type validateMountResponse struct {
type ValidateMountResponse struct {
MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"`
MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"`
MountPath string `json:"mount_path" structs:"mount_path" mapstructure:"mount_path"`
@ -75,9 +75,9 @@ func (r *Router) reset() {
r.mountAccessorCache = radix.New()
}
// validateMountByAccessor returns the mount type and ID for a given mount
// ValidateMountByAccessor returns the mount type and ID for a given mount
// accessor
func (r *Router) validateMountByAccessor(accessor string) *validateMountResponse {
func (r *Router) ValidateMountByAccessor(accessor string) *ValidateMountResponse {
if accessor == "" {
return nil
}
@ -92,7 +92,7 @@ func (r *Router) validateMountByAccessor(accessor string) *validateMountResponse
mountPath = credentialRoutePrefix + mountPath
}
return &validateMountResponse{
return &ValidateMountResponse{
MountAccessor: mountEntry.Accessor,
MountType: mountEntry.Type,
MountPath: mountPath,