VAULT-6614 Enable role based quotas for lease-count quotas (OSS) (#16157)

* VAULT-6613 add DetermineRoleFromLoginRequest function to Core

* Fix body handling

* Role resolution for rate limit quotas

* VAULT-6613 update precedence test

* Add changelog

* VAULT-6614 start of changes for roles in LCQs

* Expiration changes for leases

* Add role information to RequestAuth

* VAULT-6614 Test updates

* VAULT-6614 Add expiration test with roles

* VAULT-6614 fix comment

* VAULT-6614 Protobuf on OSS

* VAULT-6614 Add rlock to determine role code

* VAULT-6614 Try lock instead of rlock

* VAULT-6614 back to rlock while I think about this more

* VAULT-6614 Additional safety for nil dereference

* VAULT-6614 Use %q over %s

* VAULT-6614 Add overloading to plugin backends

* VAULT-6614 RLocks instead

* VAULT-6614 Fix return for backend factory
This commit is contained in:
Violet Hynes 2022-07-05 13:02:00 -04:00 committed by GitHub
parent 752c7374a9
commit 0c80ee5cf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 399 additions and 110 deletions

View File

@ -7,6 +7,8 @@ import (
"reflect"
"sync"
log "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
@ -38,7 +40,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
// Backend returns an instance of the backend, either as a plugin if external
// or as a concrete implementation if builtin, casted as logical.Backend.
func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
func Backend(ctx context.Context, conf *logical.BackendConfig) (*PluginBackend, error) {
var b PluginBackend
name := conf.Config["plugin_name"]
@ -80,7 +82,7 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
// PluginBackend is a thin wrapper around plugin.BackendPluginClient
type PluginBackend struct {
logical.Backend
Backend logical.Backend
sync.RWMutex
config *logical.BackendConfig
@ -118,12 +120,12 @@ func (b *PluginBackend) startBackend(ctx context.Context, storage logical.Storag
if !b.loaded {
if b.Backend.Type() != nb.Type() {
nb.Cleanup(ctx)
b.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType)
b.Backend.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType)
return ErrMismatchType
}
if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) {
nb.Cleanup(ctx)
b.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths)
b.Backend.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths)
return ErrMismatchPaths
}
}
@ -169,7 +171,7 @@ func (b *PluginBackend) lazyLoadBackend(ctx context.Context, storage logical.Sto
// Reload plugin if it's an rpc.ErrShutdown
b.Lock()
if b.canary == canary {
b.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"])
b.Backend.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"])
err := b.startBackend(ctx, storage)
if err != nil {
b.Unlock()
@ -220,3 +222,52 @@ func (b *PluginBackend) HandleExistenceCheck(ctx context.Context, req *logical.R
func (b *PluginBackend) Initialize(ctx context.Context, req *logical.InitializationRequest) error {
return nil
}
// SpecialPaths is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) SpecialPaths() *logical.Paths {
b.RLock()
defer b.RUnlock()
return b.Backend.SpecialPaths()
}
// System is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) System() logical.SystemView {
b.RLock()
defer b.RUnlock()
return b.Backend.System()
}
// Logger is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) Logger() log.Logger {
b.RLock()
defer b.RUnlock()
return b.Backend.Logger()
}
// Cleanup is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) Cleanup(ctx context.Context) {
b.RLock()
defer b.RUnlock()
b.Backend.Cleanup(ctx)
}
// InvalidateKey is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) InvalidateKey(ctx context.Context, key string) {
b.RLock()
defer b.RUnlock()
b.Backend.InvalidateKey(ctx, key)
}
// Setup is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) Setup(ctx context.Context, config *logical.BackendConfig) error {
b.RLock()
defer b.RUnlock()
return b.Backend.Setup(ctx, config)
}
// Type is a thin wrapper used to ensure we grab the lock for race purposes
func (b *PluginBackend) Type() logical.BackendType {
b.RLock()
defer b.RUnlock()
return b.Backend.Type()
}

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: helper/forwarding/types.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: helper/identity/mfa/types.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: helper/identity/types.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: helper/storagepacker/types.proto

View File

@ -64,7 +64,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
Type: quotas.TypeRateLimit,
Path: path,
MountPath: mountPath,
Role: core.DetermineRoleFromLoginRequest(mountPath, bodyBytes, r.Context()),
Role: core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context()),
NamespacePath: ns.Path,
ClientAddress: parseRemoteIPAddress(r),
})

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: physical/raft/types.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: sdk/database/dbplugin/database.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: sdk/database/dbplugin/v5/proto/database.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: sdk/helper/pluginutil/multiplexing.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: sdk/logical/identity.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: sdk/logical/plugin.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: sdk/plugin/pb/backend.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: vault/activity/activity_log.proto

View File

@ -175,7 +175,7 @@ func (e *ErrInvalidKey) Error() string {
return fmt.Sprintf("invalid key: %v", e.Reason)
}
type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth) error
type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth, string) error
type activeAdvertisement struct {
RedirectAddr string `json:"redirect_addr"`
@ -3324,15 +3324,9 @@ func (c *Core) CheckPluginPerms(pluginName string) (err error) {
return err
}
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
// login request
func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, payload []byte, ctx context.Context) string {
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential {
// Role based quotas do not apply to this request
return ""
}
// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given
// login request, accepting a byte payload
func (c *Core) DetermineRoleFromLoginRequestFromBytes(mountPoint string, payload []byte, ctx context.Context) string {
data := make(map[string]interface{})
err := jsonutil.DecodeJSON(payload, &data)
if err != nil {
@ -3340,6 +3334,20 @@ func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, payload []byte,
return ""
}
return c.DetermineRoleFromLoginRequest(mountPoint, data, ctx)
}
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
// login request
func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, data map[string]interface{}, ctx context.Context) string {
c.authLock.RLock()
defer c.authLock.RUnlock()
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential {
// Role based quotas do not apply to this request
return ""
}
resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{
MountPoint: mountPoint,
Path: "login",

View File

@ -166,7 +166,7 @@ func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quot
return nil
}
func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error {
func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leases []*quotas.QuotaLeaseInformation) error {
return nil
}

View File

@ -471,9 +471,15 @@ func (m *ExpirationManager) invalidate(key string) {
m.pending.Delete(leaseID)
m.leaseCount--
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
m.logger.Error("failed to update quota on lease invalidation", "error", err)
return
// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to
// accurately update quota lease information.
// Note that cachedLeaseInfo should never be nil under normal operation.
if pending.cachedLeaseInfo != nil {
leaseInfo := &quotas.QuotaLeaseInformation{LeaseId: leaseID, Role: pending.cachedLeaseInfo.LoginRole}
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil {
m.logger.Error("failed to update quota on lease invalidation", "error", err)
return
}
}
default:
// Update the lease in memory
@ -486,14 +492,21 @@ func (m *ExpirationManager) invalidate(key string) {
// other maps, and update metrics/quotas if appropriate.
m.nonexpiring.Delete(leaseID)
if _, ok := m.irrevocable.Load(leaseID); ok {
if info, ok := m.irrevocable.Load(leaseID); ok {
irrevocable := info.(pendingInfo)
m.irrevocable.Delete(leaseID)
m.irrevocableLeaseCount--
m.leaseCount--
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
m.logger.Error("failed to update quota on lease invalidation", "error", err)
return
// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to
// accurately update quota lease information.
// Note that cachedLeaseInfo should never be nil under normal operation.
if irrevocable.cachedLeaseInfo != nil {
leaseInfo := &quotas.QuotaLeaseInformation{LeaseId: leaseID, Role: irrevocable.cachedLeaseInfo.LoginRole}
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil {
m.logger.Error("failed to update quota on lease invalidation", "error", err)
return
}
}
}
return
@ -1389,7 +1402,7 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
// Register is used to take a request and response with an associated
// lease. The secret gets assigned a LeaseID and the management of
// of lease is assumed by the expiration manager.
func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response) (id string, retErr error) {
func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response, loginRole string) (id string, retErr error) {
defer metrics.MeasureSince([]string{"expire", "register"}, time.Now())
te := req.TokenEntry()
@ -1431,6 +1444,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
Path: req.Path,
Data: resp.Data,
Secret: resp.Secret,
LoginRole: loginRole,
IssueTime: time.Now(),
ExpireTime: resp.Secret.ExpirationTime(),
namespace: ns,
@ -1524,7 +1538,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
// RegisterAuth is used to take an Auth response with an associated lease.
// The token does not get a LeaseID, but the lease management is handled by
// the expiration manager.
func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth) error {
func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth, loginRole string) error {
defer metrics.MeasureSince([]string{"expire", "register-auth"}, time.Now())
// Triggers failure of RegisterAuth. This should only be set and triggered
@ -1576,6 +1590,7 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE
ClientToken: auth.ClientToken,
Auth: auth,
Path: te.Path,
LoginRole: loginRole,
IssueTime: time.Now(),
ExpireTime: authExpirationTime,
namespace: tokenNS,
@ -1721,6 +1736,7 @@ func (m *ExpirationManager) inMemoryLeaseInfo(le *leaseEntry) *leaseEntry {
if le.isIrrevocable() {
ret.RevokeErr = le.RevokeErr
}
ret.LoginRole = le.LoginRole
return ret
}
@ -1795,9 +1811,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry) {
info.(pendingInfo).timer.Stop()
m.pending.Delete(le.LeaseID)
m.leaseCount--
if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil {
m.logger.Error("failed to update quota on lease deletion", "error", err)
return
// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to
// accurately update quota lease information.
// Note that cachedLeaseInfo should never be nil under normal operation.
if pending.cachedLeaseInfo != nil {
leaseInfo := &quotas.QuotaLeaseInformation{LeaseId: le.LeaseID, Role: le.LoginRole}
if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil {
m.logger.Error("failed to update quota on lease deletion", "error", err)
return
}
}
}
return
@ -1849,9 +1871,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry) {
if leaseCreated {
m.leaseCount++
if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil {
m.logger.Error("failed to update quota on lease creation", "error", err)
return
// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to
// accurately update quota lease information.
// Note that cachedLeaseInfo should never be nil under normal operation.
if pending.cachedLeaseInfo != nil {
leaseInfo := &quotas.QuotaLeaseInformation{LeaseId: le.LeaseID, Role: le.LoginRole}
if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil {
m.logger.Error("failed to update quota on lease creation", "error", err)
return
}
}
}
}
@ -2450,9 +2478,15 @@ func (m *ExpirationManager) removeFromPending(ctx context.Context, leaseID strin
m.pending.Delete(leaseID)
if decrementCounters {
m.leaseCount--
// Log but do not fail; unit tests (and maybe Tidy on production systems)
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
m.logger.Error("failed to update quota on revocation", "error", err)
// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to
// accurately update quota lease information.
// Note that cachedLeaseInfo should never be nil under normal operation.
if pending.cachedLeaseInfo != nil {
leaseInfo := &quotas.QuotaLeaseInformation{LeaseId: leaseID, Role: pending.cachedLeaseInfo.LoginRole}
// Log but do not fail; unit tests (and maybe Tidy on production systems)
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil {
m.logger.Error("failed to update quota on revocation", "error", err)
}
}
}
}
@ -2663,6 +2697,11 @@ type leaseEntry struct {
ExpireTime time.Time `json:"expire_time"`
LastRenewalTime time.Time `json:"last_renewal_time"`
// LoginRole is used to indicate which login role (if applicable) this lease
// was created with. This is required to decrement lease count quotas
// based on login roles upon lease expiry.
LoginRole string `json:"login_role"`
// Version is used to track new different versions of leases. V0 (or
// zero-value) had non-root namespaced secondary indexes live in the root
// namespace, and V1 has secondary indexes live in the matching namespace.

View File

@ -324,6 +324,103 @@ func TestExpiration_TotalLeaseCount(t *testing.T) {
}
}
func TestExpiration_TotalLeaseCount_WithRoles(t *testing.T) {
// Quotas and internal lease count tracker are coupled, so this is a proxy
// for testing the total lease count quota
c, _, _ := TestCoreUnsealed(t)
exp := c.expiration
expectedCount := 0
otherNS := &namespace.Namespace{
ID: "nsid",
Path: "foo/bar",
}
for i := 0; i < 50; i++ {
le := &leaseEntry{
LeaseID: "lease" + fmt.Sprintf("%d", i),
Path: "foo/bar/" + fmt.Sprintf("%d", i),
LoginRole: "loginRole" + fmt.Sprintf("%d", i),
namespace: namespace.RootNamespace,
IssueTime: time.Now(),
ExpireTime: time.Now().Add(time.Hour),
}
otherNSle := &leaseEntry{
LeaseID: "lease" + fmt.Sprintf("%d", i) + "/blah.nsid",
Path: "foo/bar/" + fmt.Sprintf("%d", i) + "/blah.nsid",
LoginRole: "loginRole" + fmt.Sprintf("%d", i),
namespace: otherNS,
IssueTime: time.Now(),
ExpireTime: time.Now().Add(time.Hour),
}
exp.pendingLock.Lock()
if err := exp.persistEntry(namespace.RootContext(nil), le); err != nil {
exp.pendingLock.Unlock()
t.Fatalf("error persisting irrevocable entry: %v", err)
}
exp.updatePendingInternal(le)
expectedCount++
if err := exp.persistEntry(namespace.RootContext(nil), otherNSle); err != nil {
exp.pendingLock.Unlock()
t.Fatalf("error persisting irrevocable entry: %v", err)
}
exp.updatePendingInternal(otherNSle)
expectedCount++
exp.pendingLock.Unlock()
}
// add some irrevocable leases to each count to ensure they are counted too
// note: irrevocable leases almost certainly have an expire time set in the
// past, but for this exercise it should be fine to set it to whatever
for i := 50; i < 60; i++ {
le := &leaseEntry{
LeaseID: "lease" + fmt.Sprintf("%d", i+1),
Path: "foo/bar/" + fmt.Sprintf("%d", i+1),
LoginRole: "loginRole" + fmt.Sprintf("%d", i),
namespace: namespace.RootNamespace,
IssueTime: time.Now(),
ExpireTime: time.Now(),
RevokeErr: "some err message",
}
otherNSle := &leaseEntry{
LeaseID: "lease" + fmt.Sprintf("%d", i+1) + "/blah.nsid",
Path: "foo/bar/" + fmt.Sprintf("%d", i+1) + "/blah.nsid",
LoginRole: "loginRole" + fmt.Sprintf("%d", i),
namespace: otherNS,
IssueTime: time.Now(),
ExpireTime: time.Now(),
RevokeErr: "some err message",
}
exp.pendingLock.Lock()
if err := exp.persistEntry(namespace.RootContext(nil), le); err != nil {
exp.pendingLock.Unlock()
t.Fatalf("error persisting irrevocable entry: %v", err)
}
exp.updatePendingInternal(le)
expectedCount++
if err := exp.persistEntry(namespace.RootContext(nil), otherNSle); err != nil {
exp.pendingLock.Unlock()
t.Fatalf("error persisting irrevocable entry: %v", err)
}
exp.updatePendingInternal(otherNSle)
expectedCount++
exp.pendingLock.Unlock()
}
exp.pendingLock.RLock()
count := exp.leaseCount
exp.pendingLock.RUnlock()
if count != expectedCount {
t.Errorf("bad lease count. expected %d, got %d", expectedCount, count)
}
}
func TestExpiration_Tidy(t *testing.T) {
var err error
@ -477,7 +574,7 @@ func TestExpiration_Tidy(t *testing.T) {
"test_key": "test_value",
},
}
_, err := exp.Register(namespace.RootContext(nil), req, resp)
_, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -636,7 +733,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
"secret_key": "abcd",
},
}
_, err = exp.Register(namespace.RootContext(nil), req, resp)
_, err = exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
b.Fatalf("err: %v", err)
}
@ -698,7 +795,7 @@ func BenchmarkExpiration_Create_Leases(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
req.Path = fmt.Sprintf("prod/aws/%d", i)
_, err = exp.Register(namespace.RootContext(nil), req, resp)
_, err = exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
b.Fatalf("err: %v", err)
}
@ -743,7 +840,7 @@ func TestExpiration_Restore(t *testing.T) {
"secret_key": "abcd",
},
}
_, err := exp.Register(namespace.RootContext(nil), req, resp)
_, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -815,7 +912,7 @@ func TestExpiration_Register(t *testing.T) {
},
}
id, err := exp.Register(namespace.RootContext(nil), req, resp)
id, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -829,6 +926,49 @@ func TestExpiration_Register(t *testing.T) {
}
}
func TestExpiration_Register_Role(t *testing.T) {
exp := mockExpiration(t)
role := "role1"
req := &logical.Request{
Operation: logical.ReadOperation,
Path: "prod/aws/foo",
ClientToken: "foobar",
}
req.SetTokenEntry(&logical.TokenEntry{ID: "foobar", NamespaceID: "root"})
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
TTL: time.Hour,
},
},
Data: map[string]interface{}{
"access_key": "xyz",
"secret_key": "abcd",
},
}
id, err := exp.Register(namespace.RootContext(nil), req, resp, role)
if err != nil {
t.Fatalf("err: %v", err)
}
if !strings.HasPrefix(id, req.Path) {
t.Fatalf("bad: %s", id)
}
if len(id) <= len(req.Path) {
t.Fatalf("bad: %s", id)
}
le, err := exp.loadEntry(exp.quitContext, id)
if err != nil {
t.Fatalf("err: %v", err)
}
if le.LoginRole != role {
t.Fatalf("Login role incorrect. Expected %s, received %s", role, le.LoginRole)
}
}
func TestExpiration_Register_BatchToken(t *testing.T) {
c, _, rootToken := TestCoreUnsealed(t)
exp := c.expiration
@ -883,7 +1023,7 @@ func TestExpiration_Register_BatchToken(t *testing.T) {
},
}
leaseID, err := exp.Register(namespace.RootContext(nil), req, resp)
leaseID, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -952,7 +1092,7 @@ func TestExpiration_RegisterAuth(t *testing.T) {
Path: "auth/github/login",
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -961,7 +1101,41 @@ func TestExpiration_RegisterAuth(t *testing.T) {
Path: "auth/github/../login",
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err == nil {
t.Fatal("expected error")
}
}
func TestExpiration_RegisterAuth_Role(t *testing.T) {
exp := mockExpiration(t)
role := "role1"
root, err := exp.tokenStore.rootToken(context.Background())
if err != nil {
t.Fatalf("err: %v", err)
}
auth := &logical.Auth{
ClientToken: root.ID,
LeaseOptions: logical.LeaseOptions{
TTL: time.Hour,
},
}
te := &logical.TokenEntry{
Path: "auth/github/login",
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, role)
if err != nil {
t.Fatalf("err: %v", err)
}
te = &logical.TokenEntry{
Path: "auth/github/../login",
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, role)
if err == nil {
t.Fatal("expected error")
}
@ -985,7 +1159,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) {
Policies: []string{"root"},
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1034,13 +1208,13 @@ func TestExpiration_RegisterAuth_NoTTL(t *testing.T) {
}
// First on core
err = c.RegisterAuth(ctx, 0, "auth/github/login", auth)
err = c.RegisterAuth(ctx, 0, "auth/github/login", auth, "")
if err != nil {
t.Fatal(err)
}
auth.TokenPolicies[0] = "default"
err = c.RegisterAuth(ctx, 0, "auth/github/login", auth)
err = c.RegisterAuth(ctx, 0, "auth/github/login", auth, "")
if err == nil {
t.Fatal("expected error")
}
@ -1053,14 +1227,14 @@ func TestExpiration_RegisterAuth_NoTTL(t *testing.T) {
Policies: []string{"root"},
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(ctx, te, auth)
err = exp.RegisterAuth(ctx, te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
// Test non-root token with zero TTL
te.Policies = []string{"default"}
err = exp.RegisterAuth(ctx, te, auth)
err = exp.RegisterAuth(ctx, te, auth, "")
if err == nil {
t.Fatal("expected error")
}
@ -1098,7 +1272,7 @@ func TestExpiration_Revoke(t *testing.T) {
},
}
id, err := exp.Register(namespace.RootContext(nil), req, resp)
id, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1145,7 +1319,7 @@ func TestExpiration_RevokeOnExpire(t *testing.T) {
},
}
_, err = exp.Register(namespace.RootContext(nil), req, resp)
_, err = exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1208,7 +1382,7 @@ func TestExpiration_RevokePrefix(t *testing.T) {
"secret_key": "abcd",
},
}
_, err := exp.Register(namespace.RootContext(nil), req, resp)
_, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1277,7 +1451,7 @@ func TestExpiration_RevokeByToken(t *testing.T) {
"secret_key": "abcd",
},
}
_, err := exp.Register(namespace.RootContext(nil), req, resp)
_, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1376,7 +1550,7 @@ func TestExpiration_RevokeByToken_Blocking(t *testing.T) {
"secret_key": "abcd",
},
}
_, err := exp.Register(namespace.RootContext(nil), req, resp)
_, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1448,7 +1622,7 @@ func TestExpiration_RenewToken(t *testing.T) {
Path: "auth/token/login",
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1497,7 +1671,7 @@ func TestExpiration_RenewToken_period(t *testing.T) {
Path: "auth/token/login",
NamespaceID: namespace.RootNamespaceID,
}
err := exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err := exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1578,7 +1752,7 @@ func TestExpiration_RenewToken_period_backend(t *testing.T) {
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1635,7 +1809,7 @@ func TestExpiration_RenewToken_NotRenewable(t *testing.T) {
Path: "auth/foo/login",
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1688,7 +1862,7 @@ func TestExpiration_Renew(t *testing.T) {
},
}
id, err := exp.Register(namespace.RootContext(nil), req, resp)
id, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1759,7 +1933,7 @@ func TestExpiration_Renew_NotRenewable(t *testing.T) {
},
}
id, err := exp.Register(namespace.RootContext(nil), req, resp)
id, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1810,7 +1984,7 @@ func TestExpiration_Renew_RevokeOnExpire(t *testing.T) {
},
}
id, err := exp.Register(namespace.RootContext(nil), req, resp)
id, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1887,7 +2061,7 @@ func TestExpiration_Renew_FinalSecond(t *testing.T) {
}
ctx := namespace.RootContext(nil)
id, err := exp.Register(ctx, req, resp)
id, err := exp.Register(ctx, req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1962,7 +2136,7 @@ func TestExpiration_Renew_FinalSecond_Lease(t *testing.T) {
}
ctx := namespace.RootContext(nil)
id, err := exp.Register(ctx, req, resp)
id, err := exp.Register(ctx, req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -2647,7 +2821,7 @@ func sampleToken(t *testing.T, exp *ExpirationManager, path string, expiring boo
Policies: auth.Policies,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -2822,7 +2996,7 @@ func registerOneLease(t *testing.T, ctx context.Context, exp *ExpirationManager)
},
}
leaseID, err := exp.Register(ctx, req, resp)
leaseID, err := exp.Register(ctx, req, resp, "")
if err != nil {
t.Fatal(err)
}

View File

@ -211,7 +211,7 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
}
authBackend := b.Core.router.MatchingBackend(namespace.ContextWithNamespace(ctx, ns), mountPath)
if authBackend == nil || authBackend.Type() != logical.TypeCredential {
return logical.ErrorResponse("Mount path '%s' is not a valid auth method and therefore unsuitable for use with role-based quotas", mountPath), nil
return logical.ErrorResponse("Mount path %q is not a valid auth method and therefore unsuitable for use with role-based quotas", mountPath), nil
}
// We will always error as we aren't supplying real data, but we're looking for "unsupported operation" in particular
_, err := authBackend.HandleRequest(ctx, &logical.Request{
@ -219,7 +219,7 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
Operation: logical.ResolveRoleOperation,
})
if err != nil && (err == logical.ErrUnsupportedOperation || err == logical.ErrUnsupportedPath) {
return logical.ErrorResponse("Mount path '%s' does not support use with role-based quotas", mountPath), nil
return logical.ErrorResponse("Mount path %q does not support use with role-based quotas", mountPath), nil
}
}

View File

@ -1708,7 +1708,7 @@ func TestSystemBackend_revokePrefixAuth_newUrl(t *testing.T) {
TTL: time.Hour,
},
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1772,7 +1772,7 @@ func TestSystemBackend_revokePrefixAuth_origUrl(t *testing.T) {
TTL: time.Hour,
},
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -3617,7 +3617,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) {
ClientToken: te.ID,
Accessor: te.Accessor,
Orphan: true,
}); err != nil {
}, ""); err != nil {
t.Fatal(err)
}

View File

@ -716,7 +716,7 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic
}
// MFA validation has passed. Let's generate the token
resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth)
resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth, req.Data)
if err != nil {
return nil, fmt.Errorf("failed to create a token. error: %v", err)
}
@ -742,7 +742,7 @@ func (c *Core) teardownLoginMFA() error {
// LoginMFACreateToken creates a token after the login MFA is validated.
// It also applies the lease quotas on the original login request path.
func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth) (*logical.Response, error) {
func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth, loginRequestData map[string]interface{}) (*logical.Response, error) {
auth := cachedAuth
resp := &logical.Response{
Auth: auth,
@ -761,6 +761,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu
quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, &quotas.Request{
Path: reqPath,
MountPath: strings.TrimPrefix(mountPoint, ns.Path),
Role: c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx),
NamespacePath: ns.Path,
})
@ -780,7 +781,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu
// note that we don't need to handle the error for the following function right away.
// The function takes the response as in input variable and modify it. So, the returned
// arguments are resp and err.
leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp)
leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp, loginRequestData)
if quotaResp.Access != nil {
quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)

View File

@ -168,6 +168,17 @@ type Manager struct {
lock *sync.RWMutex
}
// QuotaLeaseInformation contains all of the information lease-count quotas require
// from a lease to uniquely identify the lease-count quota to increment/decrement
type QuotaLeaseInformation struct {
// We can determine path and namespace from leaseId
LeaseId string
// We need the role as it's not part of the leaseId, and is required
// to uniquely identify a lease count quota
Role string
}
// Quota represents the common properties of every quota type
type Quota interface {
// allow checks the if the request is allowed by the quota type implementation.

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: vault/request_forwarding_service.proto

View File

@ -969,9 +969,11 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
}
leaseGenerated := false
loginRole := c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)
quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, &quotas.Request{
Path: req.Path,
MountPath: strings.TrimPrefix(req.MountPoint, ns.Path),
Role: loginRole,
NamespacePath: ns.Path,
})
if quotaErr != nil {
@ -1111,7 +1113,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
return nil, auth, retErr
}
leaseID, err := registerFunc(ctx, req, resp)
leaseID, err := registerFunc(ctx, req, resp, loginRole)
if err != nil {
c.logger.Error("failed to register lease", "request_path", req.Path, "error", err)
retErr = multierror.Append(retErr, ErrInternalError)
@ -1191,7 +1193,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
Path: resp.Auth.CreationPath,
NamespaceID: ns.ID,
}
if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth); err != nil {
if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil {
// Best-effort clean up on error, so we log the cleanup error as
// a warning but still return as internal error.
if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil {
@ -1390,6 +1392,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, &quotas.Request{
Path: req.Path,
MountPath: strings.TrimPrefix(req.MountPoint, ns.Path),
Role: c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx),
NamespacePath: ns.Path,
})
@ -1576,7 +1579,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// Attach the display name, might be used by audit backends
req.DisplayName = auth.DisplayName
leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp)
leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp, req.Data)
leaseGenerated = leaseGen
if errCreateToken != nil {
return respTokenCreate, nil, errCreateToken
@ -1607,7 +1610,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// LoginCreateToken creates a token as a result of a login request.
// If MFA is enforced, mfa/validate endpoint calls this functions
// after successful MFA validation to generate the token.
func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response) (bool, *logical.Response, error) {
func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response, loginRequestData map[string]interface{}) (bool, *logical.Response, error) {
auth := resp.Auth
source := strings.TrimPrefix(mountPoint, credentialRoutePrefix)
@ -1669,7 +1672,7 @@ func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, re
}
leaseGenerated := false
err = registerFunc(ctx, tokenTTL, reqPath, auth)
err = registerFunc(ctx, tokenTTL, reqPath, auth, c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx))
switch {
case err == nil:
if auth.TokenType != logical.TokenTypeBatch {
@ -1736,7 +1739,9 @@ func blockRequestIfErrorImpl(_ *Core, _, _ string) error { return nil }
// RegisterAuth uses a logical.Auth object to create a token entry in the token
// store, and registers a corresponding token lease to the expiration manager.
func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path string, auth *logical.Auth) error {
// role is the login role used as part of the creation of the token entry. If not
// relevant, can be omitted (by being provided as "").
func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path string, auth *logical.Auth, role string) error {
// We first assign token policies to what was returned from the backend
// via auth.Policies. Then, we get the full set of policies into
// auth.Policies from the backend + entity information -- this is not
@ -1786,7 +1791,7 @@ func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path st
auth.Renewable = false
case logical.TokenTypeService:
// Register with the expiration manager
if err := c.expiration.RegisterAuth(ctx, &te, auth); err != nil {
if err := c.expiration.RegisterAuth(ctx, &te, auth, role); err != nil {
if err := c.tokenStore.revokeOrphan(ctx, te.ID); err != nil {
c.logger.Warn("failed to clean up token lease during login request", "request_path", path, "error", err)
}

View File

@ -42,7 +42,7 @@ func forward(ctx context.Context, c *Core, req *logical.Request) (*logical.Respo
panic("forward called in OSS Vault")
}
func getLeaseRegisterFunc(c *Core) (func(context.Context, *logical.Request, *logical.Response) (string, error), error) {
func getLeaseRegisterFunc(c *Core) (func(context.Context, *logical.Request, *logical.Response, string) (string, error), error) {
return c.expiration.Register, nil
}

View File

@ -330,7 +330,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) {
NamespaceID: namespace.RootNamespaceID,
}
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), registryEntry, auth); err != nil {
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), registryEntry, auth, ""); err != nil {
t.Fatal(err)
}
@ -375,7 +375,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) {
},
ClientToken: ent.ID,
}
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil {
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil {
t.Fatal(err)
}
@ -420,7 +420,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) {
},
ClientToken: ent.ID,
}
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil {
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil {
t.Fatal(err)
}
@ -462,7 +462,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) {
},
ClientToken: ent.ID,
}
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil {
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil {
t.Fatal(err)
}
@ -496,7 +496,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) {
},
ClientToken: ent.ID,
}
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil {
if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil {
t.Fatal(err)
}
@ -572,7 +572,7 @@ func testMakeTokenViaRequestContext(t testing.TB, ctx context.Context, ts *Token
}
if resp.Auth.TokenType != logical.TokenTypeBatch {
if err := ts.expiration.RegisterAuth(ctx, te, resp.Auth); err != nil {
if err := ts.expiration.RegisterAuth(ctx, te, resp.Auth, ""); err != nil {
t.Fatal(err)
}
}
@ -618,7 +618,7 @@ func testMakeTokenDirectly(t testing.TB, ts *TokenStore, te *logical.TokenEntry)
CreationPath: te.Path,
TokenType: te.Type,
}
err := ts.expiration.RegisterAuth(namespace.RootContext(nil), te, auth)
err := ts.expiration.RegisterAuth(namespace.RootContext(nil), te, auth, "")
switch err {
case nil:
if te.Type == logical.TokenTypeBatch {
@ -861,7 +861,7 @@ func TestTokenStore_HandleRequest_Renew_Revoke_Accessor(t *testing.T) {
t.Fatal("token entry was nil")
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -1322,7 +1322,7 @@ func TestTokenStore_Revoke_Leases(t *testing.T) {
"secret_key": "abcd",
},
}
leaseID, err := ts.expiration.Register(namespace.RootContext(nil), req, resp)
leaseID, err := ts.expiration.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -2208,7 +2208,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) {
Renewable: true,
},
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -2230,7 +2230,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) {
Renewable: true,
},
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -2623,7 +2623,7 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) {
Renewable: true,
},
}
err = exp.RegisterAuth(namespace.RootContext(nil), root, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), root, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -3113,7 +3113,7 @@ func TestTokenStore_HandleRequest_RenewSelf(t *testing.T) {
Renewable: true,
},
}
err = exp.RegisterAuth(namespace.RootContext(nil), root, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), root, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -5787,7 +5787,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) {
NamespaceID: namespace.RootNamespaceID,
}
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth)
err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -5820,7 +5820,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) {
leases := []string{}
for i := 0; i < 10; i++ {
leaseID, err := exp.Register(namespace.RootContext(nil), req, resp)
leaseID, err := exp.Register(namespace.RootContext(nil), req, resp, "")
if err != nil {
t.Fatal(err)
}

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: vault/tokens/token.proto

View File

@ -325,7 +325,7 @@ DONELISTHANDLING:
}
// Register the wrapped token with the expiration manager
if err := c.expiration.RegisterAuth(ctx, &te, wAuth); err != nil {
if err := c.expiration.RegisterAuth(ctx, &te, wAuth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil {
// Revoke since it's not yet being tracked for expiration
c.tokenStore.revokeOrphan(ctx, te.ID)
c.logger.Error("failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err)