Resource Quotas: Rate Limiting (#9330)
parent ae821b2600
commit c6876fe00f
@@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error {
 // body must still be closed manually.
 func (r *Response) Error() error {
 	// 200 to 399 are okay status codes. 429 is the code for health status of
-	// standby nodes.
-	if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 {
+	// standby nodes, otherwise, 429 is treated as quota limit reached.
+	if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") {
 		return nil
 	}
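A minimal client-side sketch of the rule this hunk introduces, assuming a plain net/http caller rather than the official api.Client (the helper name okStatus is hypothetical):

package main

import (
	"fmt"
	"net/http"
)

// okStatus mirrors the check above: 2xx/3xx pass, and 429 passes only on
// /v1/sys/health, where it signals a standby node rather than a rate limit
// quota rejection.
func okStatus(resp *http.Response) bool {
	if resp.StatusCode >= 200 && resp.StatusCode < 400 {
		return true
	}
	return resp.StatusCode == http.StatusTooManyRequests &&
		resp.Request.URL.Path == "/v1/sys/health"
}

func main() {
	// Assumes a Vault server listening locally; this is a sketch, not a test.
	resp, err := http.Get("http://127.0.0.1:8200/v1/sys/health")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status ok:", okStatus(resp))
}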
go.mod
@@ -146,6 +146,7 @@ require (
 	golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9
 	golang.org/x/net v0.0.0-20200602114024-627f9648deb9
 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
+	golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1
 	golang.org/x/tools v0.0.0-20200416214402-fc959738d646
 	google.golang.org/api v0.24.0
 	google.golang.org/grpc v1.29.1
@@ -176,8 +176,8 @@ func Handler(props *vault.HandlerProperties) http.Handler {
 	// Wrap the handler in another handler to trigger all help paths.
 	helpWrappedHandler := wrapHelpHandler(mux, core)
 	corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core)
-
-	genericWrappedHandler := genericWrapping(core, corsWrappedHandler, props)
+	quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core)
+	genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props)

 	// Wrap the handler with PrintablePathCheckHandler to check for non-printable
 	// characters in the request path.
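Ordering is the point of this hunk: the generic wrapper now sits outside the quota wrapper, so it builds the logical request first and hands it down for quota enforcement before routing continues. A small self-contained sketch of the wrapping pattern (names are illustrative, not from the commit):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// wrap returns a handler that runs its own logic before delegating inward,
// the same shape as the wrappers above.
func wrap(name string, inner http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("enter", name)
		inner.ServeHTTP(w, r)
	})
}

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {})

	// Mirrors the order above: generic(quota(cors(mux))).
	h := wrap("generic", wrap("quota", wrap("cors", mux)))

	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
	// Prints: enter generic, enter quota, enter cors
}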
@@ -221,26 +221,14 @@ func (w *copyResponseWriter) WriteHeader(code int) {

 func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		origBody := new(bytes.Buffer)
-		reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody))
-		r.Body = reader
-		req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
-		if err != nil || status != 0 {
-			respondError(w, status, err)
-			return
-		}
-		if origBody != nil {
-			r.Body = ioutil.NopCloser(origBody)
-		}
 		input := &logical.LogInput{
-			Request: req,
+			Request: w.(*LogicalResponseWriter).request,
 		}

 		core.AuditLogger().AuditRequest(r.Context(), input)
 		cw := newCopyResponseWriter(w)
 		h.ServeHTTP(cw, r)
 		data := make(map[string]interface{})
-		err = jsonutil.DecodeJSON(cw.body.Bytes(), &data)
+		err := jsonutil.DecodeJSON(cw.body.Bytes(), &data)
 		if err != nil {
 			// best effort, ignore
 		}
@@ -249,7 +237,13 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
 		core.AuditLogger().AuditResponse(r.Context(), input)
 		return
 	})
 }

+// LogicalResponseWriter is used to carry the logical request from generic
+// handler down to all the middleware http handlers.
+type LogicalResponseWriter struct {
+	http.ResponseWriter
+	request *logical.Request
+}
+
 // wrapGenericHandler wraps the handler with an extra layer of handler where
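Since LogicalResponseWriter embeds http.ResponseWriter, any middleware running inside the generic wrapper can recover the logical request with a type assertion; the quota wrapper and the sys handlers below do exactly this inline. A sketch of the pattern as a hypothetical helper (assumes Vault's http package; this helper is not in the commit):

// requestFromWriter illustrates the type-assertion pattern used inline
// throughout this commit.
func requestFromWriter(w http.ResponseWriter) *logical.Request {
	lrw, ok := w.(*LogicalResponseWriter)
	if !ok {
		return nil // writer was not wrapped by the generic handler
	}
	return lrw.request
}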
@@ -288,6 +282,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler {
 		}
 		ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
 		r = r.WithContext(ctx)
+		r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))

 		switch {
 		case strings.HasPrefix(r.URL.Path, "/v1/"):
@@ -306,7 +301,27 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler {
 			return
 		}

-		h.ServeHTTP(w, r)
+		origBody := new(bytes.Buffer)
+		reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody))
+		r.Body = reader
+		req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
+		if err != nil || status != 0 {
+			respondError(w, status, err)
+			return
+		}
+		// Reset the body since logical request creation already read the
+		// request body.
+		r.Body = ioutil.NopCloser(origBody)
+
+		// Set the mount path in the request
+		req.MountPoint = core.MatchingMount(r.Context(), req.Path)
+
+		// Pass the logical request down through the response writer
+		h.ServeHTTP(&LogicalResponseWriter{
+			ResponseWriter: w,
+			request:        req,
+		}, r)

 		cancelFunc()
 		return
 	})
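The TeeReader dance exists because buildLogicalRequestNoAuth consumes r.Body; the first read is captured into a buffer and then reinstalled so downstream handlers can read the body again. A standalone sketch of that pattern:

package main

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"strings"
)

func main() {
	body := ioutil.NopCloser(strings.NewReader(`{"key":"value"}`))

	// First read goes through a TeeReader that also fills origBody.
	origBody := new(bytes.Buffer)
	first, _ := ioutil.ReadAll(io.TeeReader(body, origBody))
	fmt.Println("first read:", string(first))

	// Reinstall the captured copy so the body can be read a second time.
	restored := ioutil.NopCloser(origBody)
	second, _ := ioutil.ReadAll(restored)
	fmt.Println("second read:", string(second))
}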
@@ -141,6 +141,7 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
 	}

 	case "OPTIONS":
+	case "HEAD":
 	default:
 		return nil, nil, http.StatusMethodNotAllowed, nil
 	}
@@ -169,36 +170,32 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
 	return req, origBody, 0, nil
 }

-func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
-	req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
-	if err != nil || status != 0 {
-		return nil, nil, status, err
-	}
-
+func setupLogicalRequest(core *vault.Core, req *logical.Request, r *http.Request) (*logical.Request, int, error) {
+	var err error
 	req, err = requestAuth(core, r, req)
 	if err != nil {
 		if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
-			return nil, nil, http.StatusForbidden, nil
+			return nil, http.StatusForbidden, nil
 		}
-		return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
+		return nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
 	}

 	req, err = requestWrapInfo(r, req)
 	if err != nil {
-		return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
+		return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
 	}

 	err = parseMFAHeader(req)
 	if err != nil {
-		return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
+		return nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
 	}

 	err = requestPolicyOverride(r, req)
 	if err != nil {
-		return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
+		return nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
 	}

-	return req, origBody, 0, nil
+	return req, 0, nil
 }

 // handleLogical returns a handler for processing logical requests. These requests
@@ -257,7 +254,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Handler {
 // toggles. Refer to usage on functions for possible behaviors.
 func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		req, origBody, statusCode, err := buildLogicalRequest(core, w, r)
+		req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r)
 		if err != nil || statusCode != 0 {
 			respondError(w, statusCode, err)
 			return
@@ -270,10 +267,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler {
 			respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly)
 			return
 		}
-
-		if origBody != nil {
-			r.Body = origBody
-		}
 		forwardRequest(core, w, r)
 		return
 	}
@@ -398,9 +391,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler {
 			respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly)
 			return
 		case needsForward && !noForward:
-			if origBody != nil {
-				r.Body = origBody
-			}
 			forwardRequest(core, w, r)
 			return
 		case !ok:
@@ -281,7 +281,13 @@ func TestLogical_ListSuffix(t *testing.T) {
 	req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil)
 	req = req.WithContext(namespace.RootContext(nil))
 	req.Header.Add(consts.AuthHeaderName, rootToken)
-	lreq, _, status, err := buildLogicalRequest(core, nil, req)
+
+	lreq, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), nil, req)
 	if err != nil || status != 0 {
 		t.Fatal(err)
 	}
+
+	lreq, status, err = setupLogicalRequest(core, lreq, req)
+	if err != nil {
+		t.Fatal(err)
+	}
@@ -295,7 +301,11 @@ func TestLogical_ListSuffix(t *testing.T) {
 	req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil)
 	req = req.WithContext(namespace.RootContext(nil))
 	req.Header.Add(consts.AuthHeaderName, rootToken)
-	lreq, _, status, err = buildLogicalRequest(core, nil, req)
+	lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req)
 	if err != nil || status != 0 {
 		t.Fatal(err)
 	}
+	lreq, status, err = setupLogicalRequest(core, lreq, req)
+	if err != nil {
+		t.Fatal(err)
+	}
@@ -309,7 +319,11 @@ func TestLogical_ListSuffix(t *testing.T) {
 	req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil)
 	req = req.WithContext(namespace.RootContext(nil))
 	req.Header.Add(consts.AuthHeaderName, rootToken)
-	lreq, _, status, err = buildLogicalRequest(core, nil, req)
+	lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req)
 	if err != nil || status != 0 {
 		t.Fatal(err)
 	}
+	lreq, status, err = setupLogicalRequest(core, lreq, req)
+	if err != nil {
+		t.Fatal(err)
+	}
@@ -17,7 +17,7 @@ import (

 func handleSysSeal(core *vault.Core) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		req, _, statusCode, err := buildLogicalRequest(core, w, r)
+		req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r)
 		if err != nil || statusCode != 0 {
 			respondError(w, statusCode, err)
 			return
@@ -47,7 +47,7 @@ func handleSysSeal(core *vault.Core) http.Handler {

 func handleSysStepDown(core *vault.Core) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		req, _, statusCode, err := buildLogicalRequest(core, w, r)
+		req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r)
 		if err != nil || statusCode != 0 {
 			respondError(w, statusCode, err)
 			return
http/util.go
@@ -1,15 +1,21 @@
 package http

 import (
+	"fmt"
+	"net"
 	"net/http"
+	"strings"

+	"github.com/hashicorp/errwrap"
 	"github.com/hashicorp/vault/helper/namespace"
+	"github.com/hashicorp/vault/sdk/logical"
 	"github.com/hashicorp/vault/vault"
+	"github.com/hashicorp/vault/vault/quotas"
 )

 var (
 	adjustRequest = func(c *vault.Core, r *http.Request) (*http.Request, int) {
-		return r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)), 0
+		return r, 0
 	}

 	genericWrapping = func(core *vault.Core, in http.Handler, props *vault.HandlerProperties) http.Handler {
@@ -22,3 +28,56 @@ var (

 	nonVotersAllowed = false
 )
+
+func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ns, err := namespace.FromContext(r.Context())
+		if err != nil {
+			respondError(w, http.StatusInternalServerError, err)
+			return
+		}
+		req := w.(*LogicalResponseWriter).request
+		quotaResp, err := core.ApplyRateLimitQuota(&quotas.Request{
+			Type:          quotas.TypeRateLimit,
+			Path:          req.Path,
+			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path),
+			NamespacePath: ns.Path,
+			ClientAddress: parseRemoteIPAddress(r),
+		})
+		if err != nil {
+			core.Logger().Error("failed to apply quota", "path", req.Path, "error", err)
+			respondError(w, http.StatusUnprocessableEntity, err)
+			return
+		}
+
+		if !quotaResp.Allowed {
+			quotaErr := errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrRateLimitQuotaExceeded)
+			respondError(w, http.StatusTooManyRequests, quotaErr)
+
+			if core.Logger().IsTrace() {
+				core.Logger().Trace("request rejected due to rate limit quota violation", "request_path", req.Path)
+			}
+
+			if core.RateLimitAuditLoggingEnabled() {
+				_ = core.AuditLogger().AuditRequest(r.Context(), &logical.LogInput{
+					Request:  req,
+					OuterErr: quotaErr,
+				})
+			}
+
+			return
+		}
+
+		handler.ServeHTTP(w, r)
+		return
+	})
+}
+
+func parseRemoteIPAddress(r *http.Request) string {
+	ip, _, err := net.SplitHostPort(r.RemoteAddr)
+	if err != nil {
+		return ""
+	}
+
+	return ip
+}
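parseRemoteIPAddress depends on r.RemoteAddr having the "host:port" form that net/http guarantees for server-side requests; SplitHostPort strips the port so quota rules key on the client IP alone, and a parse failure falls back to an empty address. A quick check of that behavior:

package main

import (
	"fmt"
	"net"
)

func main() {
	ip, port, err := net.SplitHostPort("10.0.0.1:52128")
	fmt.Println(ip, port, err) // 10.0.0.1 52128 <nil>

	// Without a port the parse fails, which is why the wrapper above falls
	// back to an empty client address rather than guessing.
	_, _, err = net.SplitHostPort("10.0.0.1")
	fmt.Println(err != nil) // true
}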
@@ -28,6 +28,14 @@ var (
 	// ErrPerfStandbyForward is returned when Vault is in a state such that a
 	// perf standby cannot satisfy a request
 	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node")
+
+	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
+	// count quota being exceeded.
+	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")
+
+	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
+	// rate limit quota being exceeded.
+	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
 )

 type HTTPCodedError interface {
@@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
 		}
 	})
 	if allErrors != nil {
-		return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
+		return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors)
 	}
 	return codedErr.Code, errors.New(codedErr.Msg)
 }
@@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
 			statusCode = http.StatusBadRequest
 		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()):
 			statusCode = http.StatusBadGateway
+		case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()):
+			statusCode = http.StatusTooManyRequests
+		case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()):
+			statusCode = http.StatusTooManyRequests
 		}
 	}
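End to end, a quota rejection becomes HTTP 429 through string matching on the wrapped sentinel error: the HTTP layer wraps ErrRateLimitQuotaExceeded with errwrap, and RespondErrorCommon matches it with errwrap.Contains. A minimal sketch of that mapping:

package main

import (
	"fmt"
	"net/http"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	// Shaped like the error the quota wrapper produces.
	err := errwrap.Wrapf(`request path "pki/": {{err}}`, logical.ErrRateLimitQuotaExceeded)

	status := http.StatusBadRequest
	if errwrap.Contains(err, logical.ErrRateLimitQuotaExceeded.Error()) {
		status = http.StatusTooManyRequests
	}
	fmt.Println(status) // 429
}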
@@ -339,6 +339,11 @@ func (c *Core) disableCredentialInternal(ctx context.Context, path string, updateStorage bool) error {

 	removePathCheckers(c, entry, viewPath)

+	if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil {
+		c.logger.Error("failed to update quotas after disabling auth", "path", path, "error", err)
+		return err
+	}
+
 	if c.logger.IsInfo() {
 		c.logger.Info("disabled credential backend", "path", path)
 	}
@@ -43,6 +43,7 @@ import (
 	sr "github.com/hashicorp/vault/serviceregistration"
 	"github.com/hashicorp/vault/shamir"
 	"github.com/hashicorp/vault/vault/cluster"
+	"github.com/hashicorp/vault/vault/quotas"
 	vaultseal "github.com/hashicorp/vault/vault/seal"
 	"github.com/patrickmn/go-cache"
 	"google.golang.org/grpc"
@@ -97,6 +98,7 @@ var (
 	enterprisePostUnseal         = enterprisePostUnsealImpl
 	enterprisePreSeal            = enterprisePreSealImpl
 	enterpriseSetupFilteredPaths = enterpriseSetupFilteredPathsImpl
+	enterpriseSetupQuotas        = enterpriseSetupQuotasImpl
 	startReplication             = startReplicationImpl
 	stopReplication              = stopReplicationImpl
 	LastWAL                      = lastWALImpl
@@ -520,6 +522,8 @@ type Core struct {
 	// can test an upgrade to a version that includes the fixes from
 	// https://github.com/hashicorp/vault-enterprise/pull/1103
 	PR1103disabled bool
+
+	quotaManager *quotas.Manager
 }

 // CoreConfig is used to parameterize a core
@@ -944,7 +948,9 @@ func NewCore(conf *CoreConfig) (*Core, error) {

 	c.clusterListener.Store((*cluster.Listener)(nil))

-	err = c.adjustForSealMigration(conf.UnwrapSeal)
+	quotasLogger := conf.Logger.Named("quotas")
+	c.allLoggers = append(c.allLoggers, quotasLogger)
+	c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink)
 	if err != nil {
 		return nil, err
 	}
@@ -1822,7 +1828,10 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, performCleanup bool) error {
 		}
 	}

-	postSealInternal(c)
+	if err := postSealInternal(c); err != nil {
+		c.logger.Error("post seal error", "error", err)
+		return err
+	}

 	c.logger.Info("vault is sealed")
@@ -1892,6 +1901,9 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c *Core) error {
 	if err := c.setupCredentials(ctx); err != nil {
 		return err
 	}
+	if err := c.setupQuotas(ctx, false); err != nil {
+		return err
+	}
 	if !c.IsDRSecondary() {
 		if err := c.startRollback(); err != nil {
 			return err
@@ -2078,6 +2090,10 @@ func enterpriseSetupFilteredPathsImpl(c *Core) error {
 	return nil
 }

+func enterpriseSetupQuotasImpl(ctx context.Context, c *Core) error {
+	return nil
+}
+
 func startReplicationImpl(c *Core) error {
 	return nil
 }
@@ -2474,3 +2490,29 @@ func (c *Core) readFeatureFlags(ctx context.Context) (*FeatureFlags, error) {
 	}
 	return &flags, nil
 }
+
+// MatchingMount returns the path of the mount that will be responsible for
+// handling the given request path.
+func (c *Core) MatchingMount(ctx context.Context, reqPath string) string {
+	return c.router.MatchingMount(ctx, reqPath)
+}
+
+func (c *Core) setupQuotas(ctx context.Context, isPerfStandby bool) error {
+	if c.quotaManager == nil {
+		return nil
+	}
+
+	return c.quotaManager.Setup(ctx, c.systemBarrierView, isPerfStandby)
+}
+
+// ApplyRateLimitQuota checks the request against all the applicable quota rules
+func (c *Core) ApplyRateLimitQuota(req *quotas.Request) (quotas.Response, error) {
+	req.Type = quotas.TypeRateLimit
+	return c.quotaManager.ApplyQuota(req)
+}
+
+// RateLimitAuditLoggingEnabled returns if the quota configuration allows audit
+// logging of request rejections due to rate limiting quota rule violations.
+func (c *Core) RateLimitAuditLoggingEnabled() bool {
+	return c.quotaManager.RateLimitAuditLoggingEnabled()
+}
@@ -9,6 +9,7 @@ import (
 	"github.com/hashicorp/vault/sdk/helper/license"
 	"github.com/hashicorp/vault/sdk/logical"
 	"github.com/hashicorp/vault/sdk/physical"
+	"github.com/hashicorp/vault/vault/quotas"
 	"github.com/hashicorp/vault/vault/replication"
 )
@@ -58,7 +59,7 @@ func addExtraCredentialBackends(*Core, map[string]logical.Factory) {}

 func preUnsealInternal(context.Context, *Core) error { return nil }

-func postSealInternal(*Core) {}
+func postSealInternal(*Core) error { return nil }

 func preSealPhysical(c *Core) {
 	switch c.sealUnwrapper.(type) {
@@ -132,3 +133,23 @@ func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, chan struct{}, error) {
 func (c *Core) initSealsForMigration() {}

 func (c *Core) postSealMigration(ctx context.Context) error { return nil }
+
+func (c *Core) applyLeaseCountQuota(in *quotas.Request) (*quotas.Response, error) {
+	return &quotas.Response{Allowed: true}, nil
+}
+
+func (c *Core) ackLeaseQuota(access quotas.Access, leaseGenerated bool) error {
+	return nil
+}
+
+func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quotas.Request) bool) error {
+	return nil
+}
+
+func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error {
+	return nil
+}
+
+func (c *Core) namespaceByPath(path string) *namespace.Namespace {
+	return namespace.RootNamespace
+}
@@ -12,17 +12,18 @@ import (
 	"sync/atomic"
 	"time"

+	metrics "github.com/armon/go-metrics"
+	"github.com/hashicorp/errwrap"
+	log "github.com/hashicorp/go-hclog"
+	multierror "github.com/hashicorp/go-multierror"
+	"github.com/hashicorp/vault/helper/namespace"
 	"github.com/hashicorp/vault/sdk/framework"
 	"github.com/hashicorp/vault/sdk/helper/base62"
 	"github.com/hashicorp/vault/sdk/helper/consts"
 	"github.com/hashicorp/vault/sdk/helper/jsonutil"
 	"github.com/hashicorp/vault/sdk/helper/locksutil"
 	"github.com/hashicorp/vault/sdk/logical"
+	"github.com/hashicorp/vault/vault/quotas"
-	metrics "github.com/armon/go-metrics"
-	"github.com/hashicorp/errwrap"
-	log "github.com/hashicorp/go-hclog"
-	multierror "github.com/hashicorp/go-multierror"
-	"github.com/hashicorp/vault/helper/namespace"
 	uberAtomic "go.uber.org/atomic"
 )
@@ -256,23 +257,60 @@ func (m *ExpirationManager) inRestoreMode() bool {
 }

 func (m *ExpirationManager) invalidate(key string) {

 	switch {
 	case strings.HasPrefix(key, leaseViewPrefix):
-		// Clear from the pending expiration
 		leaseID := strings.TrimPrefix(key, leaseViewPrefix)
-		m.pendingLock.Lock()
-		if info, ok := m.pending.Load(leaseID); ok {
-			pending := info.(pendingInfo)
-			pending.timer.Stop()
-			m.pending.Delete(leaseID)
-			m.leaseCount--
-		}
-
-		// If in the nonexpiring map, remove there.
-		m.nonexpiring.Delete(leaseID)
-		m.pendingLock.Unlock()
+		ctx := m.quitContext
+		_, nsID := namespace.SplitIDFromString(leaseID)
+		leaseNS := namespace.RootNamespace
+		var err error
+		if nsID != "" {
+			leaseNS, err = NamespaceByID(ctx, nsID, m.core)
+			if err != nil {
+				m.logger.Error("failed to invalidate lease entry", "error", err)
+				return
+			}
+		}
+
+		le, err := m.loadEntryInternal(namespace.ContextWithNamespace(ctx, leaseNS), leaseID, false, false)
+		if err != nil {
+			m.logger.Error("failed to invalidate lease entry", "error", err)
+			return
+		}
+
+		m.pendingLock.Lock()
+		defer m.pendingLock.Unlock()
+		info, ok := m.pending.Load(leaseID)
+		switch {
+		case ok:
+			switch {
+			case le == nil:
+				// Handle lease deletion
+				pending := info.(pendingInfo)
+				pending.timer.Stop()
+				m.pending.Delete(leaseID)
+				m.leaseCount--
+
+				// If in the nonexpiring map, remove there.
+				m.nonexpiring.Delete(leaseID)
+
+				if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
+					m.logger.Error("failed to handle lease delete invalidation", "error", err)
+					return
+				}
+			default:
+				// Handle lease update
+				m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now()))
+			}
+		default:
+			// There is no entry in the pending map and the invalidation
+			// resulted in a nil entry. This should ideally never happen.
+			if le == nil {
+				return
+			}
+			// Handle lease creation
+			m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now()))
+		}
 	}
 }
@@ -692,13 +730,18 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, force, skipToken bool) error {
 		}
 	}

-	// Clear the expiration handler (or remove from the list of non-expiring tokens.)
+	// Clear the expiration handler
 	m.pendingLock.Lock()
 	if info, ok := m.pending.Load(leaseID); ok {
 		pending := info.(pendingInfo)
 		pending.timer.Stop()
 		m.pending.Delete(leaseID)
 		m.leaseCount--
+		if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
+			m.pendingLock.Unlock()
+			m.logger.Error("failed to handle lease path deletion", "error", err)
+			return err
+		}
 	}
 	m.nonexpiring.Delete(leaseID)
 	m.pendingLock.Unlock()
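Note the explicit Unlock before the early return: revokeCommon manages pendingLock manually rather than with defer, so the new quota callback's error path must release the lock itself. A tiny sketch of the hazard this guards against (names are hypothetical):

package main

import (
	"errors"
	"fmt"
	"sync"
)

var mu sync.Mutex

// remove mirrors the shape of revokeCommon above: with manual lock
// management, every early return must unlock first, or the next caller
// deadlocks.
func remove(fail bool) error {
	mu.Lock()
	if fail {
		mu.Unlock() // the hunk above adds exactly this kind of unlock
		return errors.New("quota callback failed")
	}
	mu.Unlock()
	return nil
}

func main() {
	fmt.Println(remove(true))
	fmt.Println(remove(false)) // would deadlock if remove(true) skipped Unlock
}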
@@ -1420,10 +1463,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal time.Duration) {
 		info.(pendingInfo).timer.Stop()
 		m.pending.Delete(le.LeaseID)
 		m.leaseCount--
+		if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil {
+			m.logger.Error("failed to handle lease path deletion", "error", err)
+			return
+		}
 		}
 		return
 	}

+	leaseCreated := false
 	// Create entry if it does not exist or reset if it does
 	if ok {
 		pending = info.(pendingInfo)
@@ -1439,12 +1487,20 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal time.Duration) {
 		}
 		// new lease
 		m.leaseCount++
+		leaseCreated = true
 	}

 	// Retain some information in-memory
 	pending.cachedLeaseInfo = m.inMemoryLeaseInfo(le)

 	m.pending.Store(le.LeaseID, pending)

+	if leaseCreated {
+		if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil {
+			m.logger.Error("failed to handle lease creation", "error", err)
+			return
+		}
+	}
 }

 // revokeEntry is used to attempt revocation of an internal entry
@@ -0,0 +1,384 @@
package quotas

import (
	"fmt"
	"testing"
	"time"

	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/logical"

	"github.com/hashicorp/vault/builtin/credential/userpass"
	"github.com/hashicorp/vault/builtin/logical/pki"
	"github.com/hashicorp/vault/helper/testhelpers/teststorage"
	"github.com/hashicorp/vault/vault"
	"go.uber.org/atomic"
)

const (
	testLookupOnlyPolicy = `
path "/auth/token/lookup" {
	capabilities = [ "create", "update"]
}
`
)

var (
	coreConfig = &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"pki": pki.Factory,
		},
		CredentialBackends: map[string]logical.Factory{
			"userpass": userpass.Factory,
		},
	}
)

func setupMounts(t *testing.T, client *api.Client) {
	t.Helper()

	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
		Type: "userpass",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
		"password": "bar",
	})
	if err != nil {
		t.Fatal(err)
	}

	err = client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}
}

func teardownMounts(t *testing.T, client *api.Client) {
	t.Helper()
	if err := client.Sys().Unmount("pki"); err != nil {
		t.Fatal(err)
	}
	if err := client.Sys().DisableAuth("userpass"); err != nil {
		t.Fatal(err)
	}
}

func testRPS(reqFunc func(numSuccess, numFail *atomic.Int32), d time.Duration) (int32, int32, time.Duration) {
	numSuccess := atomic.NewInt32(0)
	numFail := atomic.NewInt32(0)

	start := time.Now()
	end := start.Add(d)
	for time.Now().Before(end) {
		reqFunc(numSuccess, numFail)
	}

	return numSuccess.Load(), numFail.Load(), time.Since(start)
}

func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) error {
	ticker := time.Tick(tick)
	timeout := time.After(to)

	// wait for the resource to be removed
	for {
		select {
		case <-timeout:
			return fmt.Errorf("timeout exceeded waiting for resource to be deleted: %s", path)

		case <-ticker:
			resp, err := c.Logical().Read(path)
			if err != nil {
				return err
			}

			if resp == nil {
				return nil
			}
		}
	}
}

func TestQuotas_RateLimitQuota_Mount(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client
	vault.TestWaitActive(t, core)

	err := client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}

	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("pki/cert/ca_chain")

		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	// Create a rate limit quota with a low RPS of 7.7, which means we can process
	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
	// by a refill rate of 7.7 per-second.
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  7.7,
		"burst": 8,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
	}

	// ensure that we should never get more requests than allowed
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
	}

	// update the rate limit quota with a high RPS such that no requests should fail
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  1000.0,
		"burst": 3000,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, numFail, _ = testRPS(reqFunc, 5*time.Second)
	if numFail > 0 {
		t.Fatalf("unexpected number of failed requests: %d", numFail)
	}
}

func TestQuotas_RateLimitQuota_MountPrecedence(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client

	vault.TestWaitActive(t, core)

	// create PKI mount
	err := client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"common_name": "testvault.com",
		"ttl":         "200h",
		"ip_sans":     "127.0.0.1",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
		"require_cn":       false,
		"allowed_domains":  "testvault.com",
		"allow_subdomains": true,
		"max_ttl":          "2h",
		"generate_lease":   true,
	})
	if err != nil {
		t.Fatal(err)
	}

	// create a root rate limit quota
	_, err = client.Logical().Write("sys/quotas/rate-limit/root-rlq", map[string]interface{}{
		"name":  "root-rlq",
		"rate":  14.7,
		"burst": 15,
	})
	if err != nil {
		t.Fatal(err)
	}

	// create a mount rate limit quota with a lower RPS than the root rate limit quota
	_, err = client.Logical().Write("sys/quotas/rate-limit/mount-rlq", map[string]interface{}{
		"name":  "mount-rlq",
		"rate":  7.7,
		"burst": 8,
		"path":  "pki/",
	})
	if err != nil {
		t.Fatal(err)
	}

	// ensure mount rate limit quota takes precedence over root rate limit quota
	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("pki/cert/ca_chain")

		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	// ensure mount rate limit quota takes precedence over root rate limit quota
	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
	}

	// ensure that we should never get more requests than allowed
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
	}
}

func TestQuotas_RateLimitQuota(t *testing.T) {
	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
	cluster := vault.NewTestCluster(t, conf, opts)
	cluster.Start()
	defer cluster.Cleanup()

	core := cluster.Cores[0].Core
	client := cluster.Cores[0].Client

	vault.TestWaitActive(t, core)

	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
		Type: "userpass",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
		"password": "bar",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a rate limit quota with a low RPS of 7.7, which means we can process
	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
	// by a refill rate of 7.7 per-second.
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  7.7,
		"burst": 8,
	})
	if err != nil {
		t.Fatal(err)
	}

	reqFunc := func(numSuccess, numFail *atomic.Int32) {
		_, err := client.Logical().Read("sys/quotas/rate-limit/rlq")

		if err != nil {
			numFail.Add(1)
		} else {
			numSuccess.Add(1)
		}
	}

	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)

	// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))

	// ensure there were some failed requests
	if numFail == 0 {
		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
	}

	// ensure that we should never get more requests than allowed
	if want := int32(ideal + 1); numSuccess > want {
		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
	}

	// allow time (1s) for rate limit to refill before updating the quota
	time.Sleep(time.Second)

	// update the rate limit quota with a high RPS such that no requests should fail
	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
		"rate":  1000.0,
		"burst": 3000,
	})
	if err != nil {
		t.Fatal(err)
	}

	_, numFail, _ = testRPS(reqFunc, 5*time.Second)
	if numFail > 0 {
		t.Fatalf("unexpected number of failed requests: %d", numFail)
	}
}
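The bound these tests assert follows directly from the token bucket: 'burst' tokens are available immediately and 'rate' tokens arrive per second, so at most burst + rate × seconds requests can succeed in a window. With rate 7.7 and burst 8 over roughly 5 seconds that is about 8 + 7.7 × 5 = 46.5, hence the int32(ideal + 1) ceiling. A standalone check of the arithmetic:

package main

import (
	"fmt"
	"time"
)

func main() {
	elapsed := 5 * time.Second
	ideal := 8 + 7.7*float64(elapsed)/float64(time.Second)
	fmt.Println(ideal)            // 46.5
	fmt.Println(int32(ideal + 1)) // 47, the most successes the tests tolerate
}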
@@ -160,6 +160,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend {
 	b.Backend.Paths = append(b.Backend.Paths, b.metricsPath())
 	b.Backend.Paths = append(b.Backend.Paths, b.monitorPath())
 	b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath())
+	b.Backend.Paths = append(b.Backend.Paths, b.quotasPaths()...)

 	if core.rawEnabled {
 		b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...)
@@ -751,7 +752,7 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {

 	// Get all the options
 	path := data.Get("path").(string)
-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	logicalType := data.Get("type").(string)
 	description := data.Get("description").(string)
@@ -934,7 +935,7 @@ func handleErrorNoReadOnlyForward(
 // handleUnmount is used to unmount a path
 func (b *SystemBackend) handleUnmount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 	path := data.Get("path").(string)
-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	ns, err := namespace.FromContext(ctx)
 	if err != nil {
@@ -1029,6 +1030,12 @@ func (b *SystemBackend) handleRemount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 		return handleError(err)
 	}

+	// Update quotas with the new path
+	if err := b.Core.quotaManager.HandleRemount(ctx, ns.Path, sanitizePath(fromPath), sanitizePath(toPath)); err != nil {
+		b.Core.logger.Error("failed to update quotas after remount", "ns_path", ns.Path, "from_path", fromPath, "to_path", toPath, "error", err)
+		return handleError(err)
+	}
+
 	return nil, nil
 }
@@ -1060,7 +1067,7 @@ func (b *SystemBackend) handleMountTuneRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {

 // handleTuneReadCommon returns the config settings of a path
 func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (*logical.Response, error) {
-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	sysView := b.Core.router.MatchingSystemView(ctx, path)
 	if sysView == nil {
@@ -1146,7 +1153,7 @@ func (b *SystemBackend) handleMountTuneWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, data *framework.FieldData) (*logical.Response, error) {
 	repState := b.Core.ReplicationState()

-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	// Prevent protected paths from being changed
 	for _, p := range untunableMounts {
@@ -1716,7 +1723,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {

 	// Get all the options
 	path := data.Get("path").(string)
-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)
 	logicalType := data.Get("type").(string)
 	description := data.Get("description").(string)
 	pluginName := data.Get("plugin_name").(string)
@@ -1857,7 +1864,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 // handleDisableAuth is used to disable a credential backend
 func (b *SystemBackend) handleDisableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 	path := data.Get("path").(string)
-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	ns, err := namespace.FromContext(ctx)
 	if err != nil {
@@ -2272,7 +2279,7 @@ func (b *SystemBackend) handleAuditHash(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 		return logical.ErrorResponse("the \"input\" parameter is empty"), nil
 	}

-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	hash, err := b.Core.auditBroker.GetHash(ctx, path, input)
 	if err != nil {
@@ -3258,7 +3265,7 @@ func (b *SystemBackend) pathInternalUIMountRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 	if path == "" {
 		return logical.ErrorResponse("path not set"), logical.ErrInvalidRequest
 	}
-	path = sanitizeMountPath(path)
+	path = sanitizePath(path)

 	errResp := logical.ErrorResponse(fmt.Sprintf("preflight capability check returned 403, please ensure client's policies grant access to path %q", path))
@@ -3576,7 +3583,7 @@ func (b *SystemBackend) pathInternalOpenAPI(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 	return resp, nil
 }

-func sanitizeMountPath(path string) string {
+func sanitizePath(path string) string {
 	if !strings.HasSuffix(path, "/") {
 		path += "/"
 	}
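The rename is mechanical: the helper still only normalizes a trailing slash, but it now also serves quota paths, which is why the quota handlers can assume the suffix. For reference, the full helper and its behavior:

package main

import (
	"fmt"
	"strings"
)

// sanitizePath normalizes a mount or namespace path to end with a slash,
// matching the renamed helper above.
func sanitizePath(path string) string {
	if !strings.HasSuffix(path, "/") {
		path += "/"
	}
	return path
}

func main() {
	fmt.Println(sanitizePath("pki"))      // pki/
	fmt.Println(sanitizePath("ns1/pki/")) // ns1/pki/
}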
@@ -0,0 +1,272 @@
package vault

import (
	"context"
	"strings"

	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/sdk/framework"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault/quotas"
)

// quotasPaths returns paths that enable quota management
func (b *SystemBackend) quotasPaths() []*framework.Path {
	return []*framework.Path{
		{
			Pattern: "quotas/config$",
			Fields: map[string]*framework.FieldSchema{
				"enable_rate_limit_audit_logging": {
					Type:        framework.TypeBool,
					Description: "If set, starts audit logging of requests that get rejected due to rate limit quota rule violations.",
				},
			},
			Operations: map[logical.Operation]framework.OperationHandler{
				logical.UpdateOperation: &framework.PathOperation{
					Callback: b.handleQuotasConfigUpdate(),
				},
				logical.ReadOperation: &framework.PathOperation{
					Callback: b.handleQuotasConfigRead(),
				},
			},
			HelpSynopsis:    strings.TrimSpace(quotasHelp["quotas-config"][0]),
			HelpDescription: strings.TrimSpace(quotasHelp["quotas-config"][1]),
		},
		{
			Pattern: "quotas/rate-limit/?$",
			Operations: map[logical.Operation]framework.OperationHandler{
				logical.ListOperation: &framework.PathOperation{
					Callback: b.handleRateLimitQuotasList(),
				},
			},
			HelpSynopsis:    strings.TrimSpace(quotasHelp["rate-limit-list"][0]),
			HelpDescription: strings.TrimSpace(quotasHelp["rate-limit-list"][1]),
		},
		{
			Pattern: "quotas/rate-limit/" + framework.GenericNameRegex("name"),
			Fields: map[string]*framework.FieldSchema{
				"type": {
					Type:        framework.TypeString,
					Description: "Type of the quota rule.",
				},
				"name": {
					Type:        framework.TypeString,
					Description: "Name of the quota rule.",
				},
				"path": {
					Type: framework.TypeString,
					Description: `Path of the mount or namespace to apply the quota. A blank path configures a
global quota. For example namespace1/ adds a quota to a full namespace,
namespace1/auth/userpass adds a quota to userpass in namespace1.`,
				},
				"rate": {
					Type: framework.TypeFloat,
					Description: `The rate at which allowed requests are refilled per second by the quota rule.
Internally, a token-bucket algorithm is used which has a size of 'burst', initially full. The quota
limits requests to 'rate' per-second, with a maximum burst size of 'burst'. Each request takes a single
token from this bucket. The 'rate' must be positive.`,
				},
				"burst": {
					Type: framework.TypeInt,
					Description: `The maximum number of requests at any given second to be allowed by the quota
rule. There is a one-to-one mapping between requests and tokens in the rate limit quota. A client
may perform up to 'burst' requests at once, after which they may invoke additional requests at
'rate' per-second.`,
				},
			},
			Operations: map[logical.Operation]framework.OperationHandler{
				logical.UpdateOperation: &framework.PathOperation{
					Callback: b.handleRateLimitQuotasUpdate(),
				},
				logical.ReadOperation: &framework.PathOperation{
					Callback: b.handleRateLimitQuotasRead(),
				},
				logical.DeleteOperation: &framework.PathOperation{
					Callback: b.handleRateLimitQuotasDelete(),
				},
			},
			HelpSynopsis:    strings.TrimSpace(quotasHelp["rate-limit"][0]),
			HelpDescription: strings.TrimSpace(quotasHelp["rate-limit"][1]),
		},
	}
}

func (b *SystemBackend) handleQuotasConfigUpdate() framework.OperationFunc {
	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
		config, err := quotas.LoadConfig(ctx, b.Core.systemBarrierView)
		if err != nil {
			return nil, err
		}
		config.EnableRateLimitAuditLogging = d.Get("enable_rate_limit_audit_logging").(bool)

		entry, err := logical.StorageEntryJSON(quotas.ConfigPath, config)
		if err != nil {
			return nil, err
		}
		if err := req.Storage.Put(ctx, entry); err != nil {
			return nil, err
		}

		b.Core.quotaManager.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
		return nil, nil
	}
}

func (b *SystemBackend) handleQuotasConfigRead() framework.OperationFunc {
	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
		config := b.Core.quotaManager.Config()
		return &logical.Response{
			Data: map[string]interface{}{
				"enable_rate_limit_audit_logging": config.EnableRateLimitAuditLogging,
			},
		}, nil
	}
}

func (b *SystemBackend) handleRateLimitQuotasList() framework.OperationFunc {
	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
		names, err := b.Core.quotaManager.QuotaNames(quotas.TypeRateLimit)
		if err != nil {
			return nil, err
		}

		return logical.ListResponse(names), nil
	}
}

func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
		name := d.Get("name").(string)

		qType := quotas.TypeRateLimit.String()
		rate := d.Get("rate").(float64)
		if rate <= 0 {
			return logical.ErrorResponse("'rate' is invalid"), nil
		}

		burst := d.Get("burst").(int)
		if burst < int(rate) {
			return logical.ErrorResponse("'burst' must be greater than or equal to 'rate' as an integer value"), nil
		}

		mountPath := sanitizePath(d.Get("path").(string))
		ns := b.Core.namespaceByPath(mountPath)
		if ns.ID != namespace.RootNamespaceID {
			mountPath = strings.TrimPrefix(mountPath, ns.Path)
		}

		if mountPath != "" {
			match := b.Core.router.MatchingMount(namespace.ContextWithNamespace(ctx, ns), mountPath)
			if match == "" {
				return logical.ErrorResponse("invalid mount path %q", mountPath), nil
			}
		}

		// Disallow duplicate quotas with same precedence and similar
		// properties.
		quota, err := b.Core.quotaManager.QuotaByFactors(ctx, qType, ns.Path, mountPath)
		if err != nil {
			return nil, err
		}
		if quota != nil && quota.QuotaName() != name {
			return logical.ErrorResponse("quota rule with similar properties exists under the name %q", quota.QuotaName()), nil
		}

		switch {
		case quota == nil:
			quota = quotas.NewRateLimitQuota(name, ns.Path, mountPath, rate, burst)
		default:
			rlq := quota.(*quotas.RateLimitQuota)
			rlq.NamespacePath = ns.Path
			rlq.MountPath = mountPath
			rlq.Rate = rate
			rlq.Burst = burst
		}

		entry, err := logical.StorageEntryJSON(quotas.QuotaStoragePath(qType, name), quota)
		if err != nil {
			return nil, err
		}

		if err := req.Storage.Put(ctx, entry); err != nil {
			return nil, err
		}

		if err := b.Core.quotaManager.SetQuota(ctx, qType, quota, false); err != nil {
			return nil, err
		}

		return nil, nil
	}
}

func (b *SystemBackend) handleRateLimitQuotasRead() framework.OperationFunc {
	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
		name := d.Get("name").(string)
		qType := quotas.TypeRateLimit.String()

		quota, err := b.Core.quotaManager.QuotaByName(qType, name)
		if err != nil {
			return nil, err
		}
		if quota == nil {
			return nil, nil
		}

		rlq := quota.(*quotas.RateLimitQuota)

		nsPath := rlq.NamespacePath
		if rlq.NamespacePath == "root" {
			nsPath = ""
		}

		data := map[string]interface{}{
			"type":  qType,
			"name":  rlq.Name,
			"path":  nsPath + rlq.MountPath,
			"rate":  rlq.Rate,
			"burst": rlq.Burst,
		}

		return &logical.Response{
			Data: data,
		}, nil
	}
}

func (b *SystemBackend) handleRateLimitQuotasDelete() framework.OperationFunc {
	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
		name := d.Get("name").(string)
		qType := quotas.TypeRateLimit.String()

		if err := req.Storage.Delete(ctx, quotas.QuotaStoragePath(qType, name)); err != nil {
			return nil, err
		}

		if err := b.Core.quotaManager.DeleteQuota(ctx, qType, name); err != nil {
			return nil, err
		}

		return nil, nil
	}
}

var quotasHelp = map[string][2]string{
	"quotas-config": {
		"Create, update and read the quota configuration.",
		"",
	},
	"rate-limit": {
		`Get, create or update rate limit resource quota for an optional namespace or
mount.`,
		`A rate limit quota will enforce rate limiting using a token bucket algorithm. A
rate limit quota can be created at the root level or defined on a namespace or
mount by specifying a 'path'. The rate limiter is applied to each unique client
IP address. A client may invoke 'burst' requests at any given second, after
which they may invoke additional requests at 'rate' per-second.`,
	},
	"rate-limit-list": {
		"Lists the names of all the rate limit quotas.",
		"This list contains quota definitions from all the namespaces.",
	},
}
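The rate/burst semantics documented in the fields above map directly onto golang.org/x/time/rate, the dependency added in go.mod. A hedged sketch of the underlying mechanism (the commit's actual RateLimitQuota implementation is not part of this excerpt):

package main

import (
	"fmt"

	"golang.org/x/time/rate"
)

func main() {
	// A bucket holding up to 'burst' (8) tokens, refilled at 'rate' (7.7)
	// tokens per second; each request consumes one token.
	limiter := rate.NewLimiter(rate.Limit(7.7), 8)

	allowed, rejected := 0, 0
	for i := 0; i < 20; i++ {
		if limiter.Allow() {
			allowed++
		} else {
			rejected++ // would surface as HTTP 429 in the quota wrapper
		}
	}
	fmt.Println(allowed, rejected) // roughly 8 allowed at once, the rest rejected
}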
@@ -2654,7 +2654,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) {
 	// Add another mount
 	me := &MountEntry{
 		Table:   mountTableType,
-		Path:    sanitizeMountPath("kv-v1"),
+		Path:    sanitizePath("kv-v1"),
 		Type:    "kv",
 		Options: map[string]string{"version": "1"},
 	}
@@ -664,6 +664,11 @@ func (c *Core) unmountInternal(ctx context.Context, path string, updateStorage bool) error {

 	removePathCheckers(c, entry, viewPath)

+	if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil {
+		c.logger.Error("failed to update quotas after disabling mount", "path", path, "error", err)
+		return err
+	}
+
 	if c.logger.IsInfo() {
 		c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path)
 	}
@ -0,0 +1,860 @@
package quotas

import (
	"context"
	"errors"
	"fmt"
	"path"
	"strings"
	"sync"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-memdb"
	"github.com/hashicorp/vault/helper/metricsutil"
	"github.com/hashicorp/vault/sdk/logical"
)

// Type represents the quota kind
type Type string

const (
	// TypeRateLimit represents the rate limiting quota type
	TypeRateLimit Type = "rate-limit"

	// TypeLeaseCount represents the lease count limiting quota type
	TypeLeaseCount Type = "lease-count"
)

// LeaseAction is the action taken by the expiration manager on the lease. The
// quota manager will use this information to update the lease path cache and
// to update counters for relevant quota rules.
type LeaseAction uint32

// String converts each lease action into its string equivalent value
func (la LeaseAction) String() string {
	switch la {
	case LeaseActionLoaded:
		return "loaded"
	case LeaseActionCreated:
		return "created"
	case LeaseActionDeleted:
		return "deleted"
	case LeaseActionAllow:
		return "allow"
	}
	return "unknown"
}

const (
	_ LeaseAction = iota

	// LeaseActionLoaded indicates loading of a lease in the expiration manager
	// after unseal.
	LeaseActionLoaded

	// LeaseActionCreated indicates that a lease is created in the expiration manager.
	LeaseActionCreated

	// LeaseActionDeleted indicates that a lease has expired and been deleted in
	// the expiration manager.
	LeaseActionDeleted

	// LeaseActionAllow will be used to indicate to the lease count checker that
	// incCounter is called from Allow(). All the rest of the actions indicate
	// that the action took place on the lease in the expiration manager.
	LeaseActionAllow
)

type leaseWalkFunc func(context.Context, func(request *Request) bool) error

// String converts each quota type into its string equivalent value
func (q Type) String() string {
	switch q {
	case TypeLeaseCount:
		return "lease-count"
	case TypeRateLimit:
		return "rate-limit"
	}
	return "unknown"
}

const (
	indexID             = "id"
	indexName           = "name"
	indexNamespace      = "ns"
	indexNamespaceMount = "ns_mount"
)

const (
	// StoragePrefix is the prefix for the physical location where quota rules are
	// persisted.
	StoragePrefix = "quotas/"

	// ConfigPath is the physical location where the quota configuration is
	// persisted.
	ConfigPath = StoragePrefix + "config"
)

var (
	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
	// count quota being exceeded.
	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")

	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
	// rate limit quota being exceeded.
	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
)

// Access provides information to reach back to the quota checker.
type Access interface {
	// QuotaID is the identifier of the quota that issued this access.
	QuotaID() string
}

// Ensure that access implements the Access interface.
var _ Access = (*access)(nil)

// access implements the Access interface
type access struct {
	quotaID string
}

// QuotaID returns the identifier of the quota rule to which this access
// refers.
func (a *access) QuotaID() string {
	return a.quotaID
}

// Manager holds all the existing quota rules. For any given input, the manager
// checks them against any applicable quota rules.
type Manager struct {
	entManager

	// db holds the in memory instances of all active quota rules indexed by
	// some of the quota properties.
	db *memdb.MemDB

	// config containing operator preferences and quota behaviors
	config *Config

	storage logical.Storage
	ctx     context.Context

	logger     log.Logger
	metricSink *metricsutil.ClusterMetricSink
	lock       *sync.RWMutex
}

// Quota represents the common properties of every quota type
type Quota interface {
	// allow checks if the request is allowed by the quota type implementation.
	allow(*Request) (Response, error)

	// quotaID is the identifier of the quota rule
	quotaID() string

	// QuotaName is the name of the quota rule
	QuotaName() string

	// initialize sets up the fields in the quota type to begin operating
	initialize(log.Logger, *metricsutil.ClusterMetricSink) error

	// close defines any cleanup behavior that needs to be executed when a quota
	// rule is deleted.
	close() error

	// handleRemount takes in the new mount path in the quota
	handleRemount(string)
}

// Response holds information about the result of the Allow() call. The response
// can optionally have the Access field set, which is used to reach back into
// the quota rule that sent this response.
type Response struct {
	// Allowed is set if the quota allows the request
	Allowed bool

	// Access is the handle to reach back into the quota rule that processed the
	// quota request. This may not be set all the time.
	Access Access
}

// Config holds operator preferences around quota behaviors
type Config struct {
	// EnableRateLimitAuditLogging, if set, starts audit logging of the
	// request rejections that arise due to rate limit quota violations.
	EnableRateLimitAuditLogging bool `json:"enable_rate_limit_audit_logging"`
}

// Request contains information required by the quota manager to query and
// apply the quota rules.
type Request struct {
	// Type is the quota type
	Type Type

	// Path is the request path for which quota rules are being queried
	Path string

	// NamespacePath is the namespace path to which the request belongs
	NamespacePath string

	// MountPath is the mount path to which the request is made
	MountPath string

	// ClientAddress is the client's unique addressable string (e.g. IP address).
	// It can be empty if the quota type does not need it.
	ClientAddress string
}

// NewManager creates and initializes a new quota manager to hold all the quota
// rules and to process incoming requests.
func NewManager(logger log.Logger, walkFunc leaseWalkFunc, ms *metricsutil.ClusterMetricSink) (*Manager, error) {
	db, err := memdb.NewMemDB(dbSchema())
	if err != nil {
		return nil, err
	}

	manager := &Manager{
		db:         db,
		logger:     logger,
		metricSink: ms,
		config:     new(Config),
		lock:       new(sync.RWMutex),
	}

	manager.init(walkFunc)

	return manager, nil
}

// SetQuota adds a new quota rule to the db.
func (m *Manager) SetQuota(ctx context.Context, qType string, quota Quota, loading bool) error {
	m.lock.Lock()
	defer m.lock.Unlock()
	return m.setQuotaLocked(ctx, qType, quota, loading)
}

// setQuotaLocked should be called with the manager's lock held
func (m *Manager) setQuotaLocked(ctx context.Context, qType string, quota Quota, loading bool) error {
	if qType == TypeLeaseCount.String() {
		m.setIsPerfStandby(quota)
	}

	txn := m.db.Txn(true)
	defer txn.Abort()

	raw, err := txn.First(qType, "id", quota.quotaID())
	if err != nil {
		return err
	}

	// If there already exists an entry in the db, remove that first.
	if raw != nil {
		err = txn.Delete(qType, raw)
		if err != nil {
			return err
		}
	}

	// Initialize the quota type implementation
	if err := quota.initialize(m.logger, m.metricSink); err != nil {
		return err
	}

	// Add the initialized quota type implementation to the db
	if err := txn.Insert(qType, quota); err != nil {
		return err
	}

	if loading {
		txn.Commit()
		return nil
	}

	// For the lease count type, recompute the counters
	if !loading && qType == TypeLeaseCount.String() {
		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
			return err
		}
	}

	txn.Commit()
	return nil
}

// QuotaNames returns the names of all the quota rules for a given type
func (m *Manager) QuotaNames(qType Type) ([]string, error) {
	m.lock.RLock()
	defer m.lock.RUnlock()

	txn := m.db.Txn(false)
	iter, err := txn.Get(qType.String(), indexID)
	if err != nil {
		return nil, err
	}
	var names []string
	for raw := iter.Next(); raw != nil; raw = iter.Next() {
		names = append(names, raw.(Quota).QuotaName())
	}
	return names, nil
}

// QuotaByID queries for a quota rule in the db for a given quota ID
func (m *Manager) QuotaByID(qType string, id string) (Quota, error) {
	m.lock.RLock()
	defer m.lock.RUnlock()

	txn := m.db.Txn(false)

	quotaRaw, err := txn.First(qType, indexID, id)
	if err != nil {
		return nil, err
	}
	if quotaRaw == nil {
		return nil, nil
	}

	return quotaRaw.(Quota), nil
}

// QuotaByName queries for a quota rule in the db for a given quota name
func (m *Manager) QuotaByName(qType string, name string) (Quota, error) {
	m.lock.RLock()
	defer m.lock.RUnlock()

	txn := m.db.Txn(false)

	quotaRaw, err := txn.First(qType, indexName, name)
	if err != nil {
		return nil, err
	}
	if quotaRaw == nil {
		return nil, nil
	}

	return quotaRaw.(Quota), nil
}

// QuotaByFactors returns the quota rule that matches the provided factors
func (m *Manager) QuotaByFactors(ctx context.Context, qType, nsPath, mountPath string) (Quota, error) {
	m.lock.RLock()
	defer m.lock.RUnlock()

	// nsPath would have been made non-empty during insertion. Use non-empty value
	// during query as well.
	if nsPath == "" {
		nsPath = "root"
	}

	idx := indexNamespace
	args := []interface{}{nsPath, false}
	if mountPath != "" {
		idx = indexNamespaceMount
		args = []interface{}{nsPath, mountPath}
	}

	txn := m.db.Txn(false)
	iter, err := txn.Get(qType, idx, args...)
	if err != nil {
		return nil, err
	}
	var quotas []Quota
	for raw := iter.Next(); raw != nil; raw = iter.Next() {
		quotas = append(quotas, raw.(Quota))
	}
	if len(quotas) > 1 {
		return nil, fmt.Errorf("conflicting quota definitions detected")
	}
	if len(quotas) == 0 {
		return nil, nil
	}

	return quotas[0], nil
}

// queryQuota returns the quota rule that is applicable for the given request. It
// queries all the quota rules that are defined against request values and finds
// the quota rule that takes priority.
//
// Priority rules are as follows:
// - namespace specific quota takes precedence over global quota
// - mount specific quota takes precedence over namespace specific quota
func (m *Manager) queryQuota(txn *memdb.Txn, req *Request) (Quota, error) {
	if txn == nil {
		txn = m.db.Txn(false)
	}

	// ns would have been made non-empty during insertion. Use non-empty
	// value during query as well.
	if req.NamespacePath == "" {
		req.NamespacePath = "root"
	}

	//
	// Find a match from the most specific applicable quota rule to the least
	// specific one.
	//
	quotaFetchFunc := func(idx string, args ...interface{}) (Quota, error) {
		iter, err := txn.Get(req.Type.String(), idx, args...)
		if err != nil {
			return nil, err
		}
		var quotas []Quota
		for raw := iter.Next(); raw != nil; raw = iter.Next() {
			quota := raw.(Quota)
			quotas = append(quotas, quota)
		}
		if len(quotas) > 1 {
			return nil, fmt.Errorf("conflicting quota definitions detected")
		}
		if len(quotas) == 0 {
			return nil, nil
		}

		return quotas[0], nil
	}

	// Fetch mount quota
	quota, err := quotaFetchFunc(indexNamespaceMount, req.NamespacePath, req.MountPath)
	if err != nil {
		return nil, err
	}
	if quota != nil {
		return quota, nil
	}

	// Fetch ns quota. If NamespacePath is root, this will return the global quota.
	quota, err = quotaFetchFunc(indexNamespace, req.NamespacePath, false)
	if err != nil {
		return nil, err
	}
	if quota != nil {
		return quota, nil
	}

	// If the request belongs to the "root" namespace, then we have already looked
	// at global quotas when fetching the namespace specific quota rule. When the
	// request belongs to a non-root namespace, and when there are no namespace
	// specific quota rules present, we fall back on the global quotas.
	if req.NamespacePath == "root" {
		return nil, nil
	}

	// Fetch global quota
	quota, err = quotaFetchFunc(indexNamespace, "root", false)
	if err != nil {
		return nil, err
	}
	if quota != nil {
		return quota, nil
	}

	return nil, nil
}

// DeleteQuota removes a quota rule from the db for a given name
func (m *Manager) DeleteQuota(ctx context.Context, qType string, name string) error {
	m.lock.Lock()
	defer m.lock.Unlock()

	txn := m.db.Txn(true)
	defer txn.Abort()

	raw, err := txn.First(qType, indexName, name)
	if err != nil {
		return err
	}
	if raw == nil {
		return nil
	}

	quota := raw.(Quota)
	if err := quota.close(); err != nil {
		return err
	}

	err = txn.Delete(qType, raw)
	if err != nil {
		return err
	}

	// For the lease count type, recompute the counters
	if qType == TypeLeaseCount.String() {
		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
			return err
		}
	}

	txn.Commit()
	return nil
}

// ApplyQuota runs the request against any quota rule that is applicable to it. If
// there are multiple quota rules that match the request parameters, the rule that
// takes precedence will be used to allow/reject the request.
func (m *Manager) ApplyQuota(req *Request) (Response, error) {
	var resp Response

	quota, err := m.queryQuota(nil, req)
	if err != nil {
		return resp, err
	}

	// If there is no quota defined, allow the request.
	if quota == nil {
		resp.Allowed = true
		return resp, nil
	}

	// If the quota type is lease count, and if the path is not known to
	// generate leases, allow the request.
	if req.Type == TypeLeaseCount && !m.inLeasePathCache(req.Path) {
		resp.Allowed = true
		return resp, nil
	}

	return quota.allow(req)
}

// SetEnableRateLimitAuditLogging updates the operator preference regarding the
// audit logging behavior.
func (m *Manager) SetEnableRateLimitAuditLogging(val bool) {
	m.config.EnableRateLimitAuditLogging = val
}

// RateLimitAuditLoggingEnabled returns whether the quota configuration allows
// audit logging of request rejections due to rate limit quota rule violations.
func (m *Manager) RateLimitAuditLoggingEnabled() bool {
	return m.config.EnableRateLimitAuditLogging
}

// Config returns the operator preferences in the quota manager
func (m *Manager) Config() *Config {
	return m.config
}

// Reset will clear all the quotas from the db and clear the lease path cache.
func (m *Manager) Reset() error {
	m.lock.Lock()
	defer m.lock.Unlock()

	var err error
	m.db, err = memdb.NewMemDB(dbSchema())
	if err != nil {
		return err
	}

	m.storage = nil
	m.ctx = nil

	m.entManager.Reset()
	return nil
}

// dbSchema creates a DB schema for holding all the quota rules. It creates a
// table for each supported type of quota.
func dbSchema() *memdb.DBSchema {
	schema := &memdb.DBSchema{
		Tables: make(map[string]*memdb.TableSchema),
	}

	commonSchema := func(name string) *memdb.TableSchema {
		return &memdb.TableSchema{
			Name: name,
			Indexes: map[string]*memdb.IndexSchema{
				indexID: {
					Name:   indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{
						Field: "ID",
					},
				},
				indexName: {
					Name:   indexName,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{
						Field: "Name",
					},
				},
				indexNamespace: {
					Name: indexNamespace,
					Indexer: &memdb.CompoundMultiIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{
								Field: "NamespacePath",
							},
							// By sending false as the query parameter, we can
							// query just the namespace specific quota.
							&memdb.FieldSetIndex{
								Field: "MountPath",
							},
						},
					},
				},
				indexNamespaceMount: {
					Name:         indexNamespaceMount,
					AllowMissing: true,
					Indexer: &memdb.CompoundMultiIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{
								Field: "NamespacePath",
							},
							&memdb.StringFieldIndex{
								Field: "MountPath",
							},
						},
					},
				},
			},
		}
	}

	// Create a table per quota type. This allows names to be reused between
	// different quota types and makes querying a bit easier.
	for _, name := range quotaTypes() {
		schema.Tables[name] = commonSchema(name)
	}

	return schema
}

// Invalidate receives notifications from the replication sub-system when a key
// is updated in the storage. This function reads the key from storage and
// updates the caches and data structures to reflect those updates.
func (m *Manager) Invalidate(key string) {
	switch key {
	case "config":
		config, err := LoadConfig(m.ctx, m.storage)
		if err != nil {
			m.logger.Error("failed to invalidate quota config", "error", err)
			return
		}
		m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
	default:
		splitKeys := strings.Split(key, "/")
		if len(splitKeys) != 2 {
			m.logger.Error("incorrect key while invalidating quota rule")
			return
		}
		qType := splitKeys[0]
		name := splitKeys[1]

		// Read the quota rule from storage
		quota, err := Load(m.ctx, m.storage, qType, name)
		if err != nil {
			m.logger.Error("failed to read invalidated quota rule", "error", err)
			return
		}

		switch {
		case quota == nil:
			// Handle quota deletion
			if err := m.DeleteQuota(m.ctx, qType, name); err != nil {
				m.logger.Error("failed to delete invalidated quota rule", "error", err)
				return
			}
		default:
			// Handle quota update
			if err := m.SetQuota(m.ctx, qType, quota, false); err != nil {
				m.logger.Error("failed to update invalidated quota rule", "error", err)
				return
			}
		}
	}
}

// LoadConfig reads the quota configuration from the underlying storage
func LoadConfig(ctx context.Context, storage logical.Storage) (*Config, error) {
	var config Config
	entry, err := storage.Get(ctx, ConfigPath)
	if err != nil {
		return nil, err
	}
	if entry == nil {
		return &config, nil
	}

	err = entry.DecodeJSON(&config)
	if err != nil {
		return nil, err
	}

	return &config, nil
}

// Load reads the quota rule from the underlying storage
func Load(ctx context.Context, storage logical.Storage, qType, name string) (Quota, error) {
	var quota Quota
	entry, err := storage.Get(ctx, QuotaStoragePath(qType, name))
	if err != nil {
		return nil, err
	}
	if entry == nil {
		return nil, nil
	}

	switch qType {
	case TypeRateLimit.String():
		quota = &RateLimitQuota{}
	case TypeLeaseCount.String():
		quota = &LeaseCountQuota{}
	default:
		return nil, fmt.Errorf("unsupported type: %v", qType)
	}

	err = entry.DecodeJSON(quota)
	if err != nil {
		return nil, err
	}

	return quota, nil
}

// Setup loads the quota configuration and all the quota rules into the
// quota manager.
func (m *Manager) Setup(ctx context.Context, storage logical.Storage, isPerfStandby bool) error {
	m.lock.Lock()
	defer m.lock.Unlock()

	m.storage = storage
	m.ctx = ctx
	m.isPerfStandby = isPerfStandby

	// Load the quota configuration from storage and load it into the quota
	// manager.
	config, err := LoadConfig(ctx, storage)
	if err != nil {
		return err
	}
	m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)

	// Load the quota rules for all supported types from storage and load them
	// into the quota manager.
	for _, qType := range quotaTypes() {
		names, err := logical.CollectKeys(ctx, logical.NewStorageView(storage, StoragePrefix+qType+"/"))
		if err != nil {
			return err
		}
		for _, name := range names {
			quota, err := Load(ctx, m.storage, qType, name)
			if err != nil {
				return err
			}

			if quota == nil {
				continue
			}

			err = m.setQuotaLocked(ctx, qType, quota, true)
			if err != nil {
				return err
			}
		}
	}

	return nil
}

// QuotaStoragePath returns the storage path suffix for persisting the quota
// rule.
func QuotaStoragePath(quotaType, name string) string {
	return path.Join(StoragePrefix+quotaType, name)
}

// HandleRemount updates the quota subsystem about the remount operation that
// took place. The quota manager will trigger the quota specific updates,
// including the mount path update.
func (m *Manager) HandleRemount(ctx context.Context, nsPath, fromPath, toPath string) error {
	m.lock.Lock()
	defer m.lock.Unlock()

	txn := m.db.Txn(true)
	defer txn.Abort()

	// nsPath would have been made non-empty during insertion. Use non-empty value
	// during query as well.
	if nsPath == "" {
		nsPath = "root"
	}

	idx := indexNamespaceMount
	leaseQuotaUpdated := false
	args := []interface{}{nsPath, fromPath}
	for _, quotaType := range quotaTypes() {
		iter, err := txn.Get(quotaType, idx, args...)
		if err != nil {
			return err
		}
		for raw := iter.Next(); raw != nil; raw = iter.Next() {
			quota := raw.(Quota)
			quota.handleRemount(toPath)
			entry, err := logical.StorageEntryJSON(QuotaStoragePath(quotaType, quota.QuotaName()), quota)
			if err != nil {
				return err
			}
			if err := m.storage.Put(ctx, entry); err != nil {
				return err
			}
			if quotaType == TypeLeaseCount.String() {
				leaseQuotaUpdated = true
			}
		}
	}

	if leaseQuotaUpdated {
		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
			return err
		}
	}

	txn.Commit()

	return nil
}

// HandleBackendDisabling updates the quota subsystem when an auth or secrets
// engine is disabled.
func (m *Manager) HandleBackendDisabling(ctx context.Context, nsPath, mountPath string) error {
	m.lock.Lock()
	defer m.lock.Unlock()

	txn := m.db.Txn(true)
	defer txn.Abort()

	// nsPath would have been made non-empty during insertion. Use non-empty value
	// during query as well.
	if nsPath == "" {
		nsPath = "root"
	}

	idx := indexNamespaceMount
	leaseQuotaDeleted := false
	args := []interface{}{nsPath, mountPath}
	for _, quotaType := range quotaTypes() {
		iter, err := txn.Get(quotaType, idx, args...)
		if err != nil {
			return err
		}
		for raw := iter.Next(); raw != nil; raw = iter.Next() {
			if err := txn.Delete(quotaType, raw); err != nil {
				return fmt.Errorf("failed to delete quota from db after mount disabling; namespace %q, err %v", nsPath, err)
			}
			quota := raw.(Quota)
			if err := m.storage.Delete(ctx, QuotaStoragePath(quotaType, quota.QuotaName())); err != nil {
				return fmt.Errorf("failed to delete quota from storage after mount disabling; namespace %q, err %v", nsPath, err)
			}
			if quotaType == TypeLeaseCount.String() {
				leaseQuotaDeleted = true
			}
		}
	}

	if leaseQuotaDeleted {
		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
			return err
		}
	}

	txn.Commit()

	return nil
}
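
A hypothetical standalone sketch (not part of the commit) of the FieldSetIndex trick behind the "ns" index defined above: passing false as the final query argument matches only rows whose MountPath is unset, that is, namespace-wide quotas. The table and struct names here are invented for illustration.

package main

import (
	"fmt"

	"github.com/hashicorp/go-memdb"
)

type rule struct {
	ID            string
	NamespacePath string
	MountPath     string
}

func main() {
	schema := &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"rule": {
				Name: "rule",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
					"ns": {
						Name: "ns",
						Indexer: &memdb.CompoundMultiIndex{
							Indexes: []memdb.Indexer{
								&memdb.StringFieldIndex{Field: "NamespacePath"},
								// Indexes whether MountPath is set at all.
								&memdb.FieldSetIndex{Field: "MountPath"},
							},
						},
					},
				},
			},
		},
	}

	db, err := memdb.NewMemDB(schema)
	if err != nil {
		panic(err)
	}

	txn := db.Txn(true)
	txn.Insert("rule", &rule{ID: "1", NamespacePath: "root"})
	txn.Insert("rule", &rule{ID: "2", NamespacePath: "root", MountPath: "kv/"})
	txn.Commit()

	// Querying with false matches only the namespace-wide rule.
	read := db.Txn(false)
	raw, _ := read.First("rule", "ns", "root", false)
	fmt.Println(raw.(*rule).ID) // prints "1"
}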

@ -0,0 +1,282 @@
package quotas

import (
	"fmt"
	"math"
	"sync"
	"time"

	"github.com/armon/go-metrics"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/helper/metricsutil"
	"github.com/hashicorp/vault/sdk/helper/pathmanager"
	"golang.org/x/time/rate"
)

var rateLimitExemptPaths = pathmanager.New()

const (
	// DefaultRateLimitPurgeInterval defines the default purge interval used by a
	// RateLimitQuota to remove stale client rate limiters.
	DefaultRateLimitPurgeInterval = time.Minute

	// DefaultRateLimitStaleAge defines the default stale age of a client limiter.
	DefaultRateLimitStaleAge = 3 * time.Minute

	// EnvVaultEnableRateLimitAuditLogging is used to enable audit logging of
	// requests that get rejected due to rate limit quota violations.
	EnvVaultEnableRateLimitAuditLogging = "VAULT_ENABLE_RATE_LIMIT_AUDIT_LOGGING"
)

func init() {
	rateLimitExemptPaths.AddPaths([]string{
		"/v1/sys/generate-recovery-token/attempt",
		"/v1/sys/generate-recovery-token/update",
		"/v1/sys/generate-root/attempt",
		"/v1/sys/generate-root/update",
		"/v1/sys/health",
		"/v1/sys/seal-status",
		"/v1/sys/unseal",
	})
}

// ClientRateLimiter defines a token bucket based rate limiter for a unique
// addressable client (e.g. IP address). Whenever this client attempts to make
// a request, the lastSeen value will be updated.
type ClientRateLimiter struct {
	// lastSeen is the time at which the client last made a request.
	lastSeen time.Time

	// limiter represents an instance of a token bucket based rate limiter.
	limiter *rate.Limiter
}

// newClientRateLimiter returns a token bucket based rate limiter for a client
// that is uniquely addressable, where maxRequests defines the requests-per-second
// and burstSize defines the maximum burst allowed. A caller may provide -1 for
// burstSize to allow the burst value to be roughly equivalent to the RPS. Note,
// the underlying rate limiter is already thread-safe.
func newClientRateLimiter(maxRequests float64, burstSize int) *ClientRateLimiter {
	if burstSize < 0 {
		burstSize = int(math.Ceil(maxRequests))
	}

	return &ClientRateLimiter{
		lastSeen: time.Now().UTC(),
		limiter:  rate.NewLimiter(rate.Limit(maxRequests), burstSize),
	}
}

// Ensure that RateLimitQuota implements the Quota interface
var _ Quota = (*RateLimitQuota)(nil)

// RateLimitQuota represents the quota rule properties that are used to limit
// the number of requests per second for a namespace or mount.
type RateLimitQuota struct {
	// ID is the identifier of the quota
	ID string `json:"id"`

	// Type of quota this represents
	Type Type `json:"type"`

	// Name of the quota rule
	Name string `json:"name"`

	// NamespacePath is the path of the namespace to which this quota is
	// applicable.
	NamespacePath string `json:"namespace_path"`

	// MountPath is the path of the mount to which this quota is applicable
	MountPath string `json:"mount_path"`

	// Rate defines the rate at which allowed requests are refilled per second.
	Rate float64 `json:"rate"`

	// Burst defines the maximum number of requests allowed at any given moment.
	Burst int `json:"burst"`

	lock         *sync.Mutex
	logger       log.Logger
	metricSink   *metricsutil.ClusterMetricSink
	purgeEnabled bool

	// purgeInterval defines the interval at which the RateLimitQuota attempts
	// to remove stale entries from the rateQuotas mapping.
	purgeInterval time.Duration
	closeCh       chan struct{}

	// staleAge defines the age at which a ClientRateLimiter is considered
	// stale. A ClientRateLimiter is considered stale if the delta between the
	// current purge time and its lastSeen timestamp is greater than this value.
	staleAge time.Duration

	// rateQuotas contains a mapping from a unique addressable client (e.g. IP address)
	// to a ClientRateLimiter reference. Every purgeInterval, the RateLimitQuota
	// will attempt to remove stale entries from the mapping.
	rateQuotas map[string]*ClientRateLimiter
}

// NewRateLimitQuota creates a quota checker for imposing limits on the number
// of requests per second.
func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, burst int) *RateLimitQuota {
	return &RateLimitQuota{
		Name:          name,
		Type:          TypeRateLimit,
		NamespacePath: nsPath,
		MountPath:     mountPath,
		Rate:          rate,
		Burst:         burst,
	}
}

// initialize ensures the namespace and max requests are initialized, sets the ID
// if it's currently empty, sets the purge interval and stale age to default
// values, and finally starts the client purge go routine if it hasn't been
// started already. Note, initialize will reset the internal rateQuotas mapping.
func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.ClusterMetricSink) error {
	if rlq.lock == nil {
		rlq.lock = new(sync.Mutex)
	}

	rlq.lock.Lock()
	defer rlq.lock.Unlock()

	// Memdb requires a non-empty value for indexing
	if rlq.NamespacePath == "" {
		rlq.NamespacePath = "root"
	}

	if rlq.Rate <= 0 {
		return fmt.Errorf("invalid avg rps: %v", rlq.Rate)
	}

	if rlq.Burst < int(rlq.Rate) {
		return fmt.Errorf("burst size (%v) must be greater than or equal to average rps (%v)", rlq.Burst, rlq.Rate)
	}

	if logger != nil {
		rlq.logger = logger
	}

	if rlq.metricSink == nil {
		rlq.metricSink = ms
	}

	if rlq.ID == "" {
		id, err := uuid.GenerateUUID()
		if err != nil {
			return err
		}

		rlq.ID = id
	}

	rlq.purgeInterval = DefaultRateLimitPurgeInterval
	rlq.staleAge = DefaultRateLimitStaleAge
	rlq.rateQuotas = make(map[string]*ClientRateLimiter)

	if !rlq.purgeEnabled {
		rlq.purgeEnabled = true
		rlq.closeCh = make(chan struct{})
		go rlq.purgeClientsLoop()
	}

	return nil
}

// quotaID returns the identifier of the quota rule
func (rlq *RateLimitQuota) quotaID() string {
	return rlq.ID
}

// QuotaName returns the name of the quota rule
func (rlq *RateLimitQuota) QuotaName() string {
	return rlq.Name
}

// purgeClientsLoop performs a blocking process where every purgeInterval
// duration, we look for stale clients to remove from the rateQuotas map.
// A ClientRateLimiter is considered stale if the time elapsed since its
// lastSeen timestamp is at least staleAge. The loop will continue to run
// indefinitely until a value is sent on the closeCh, at which point we stop
// the ticker and exit.
func (rlq *RateLimitQuota) purgeClientsLoop() {
	ticker := time.NewTicker(rlq.purgeInterval)

	for {
		select {
		case t := <-ticker.C:
			rlq.lock.Lock()

			for client, crl := range rlq.rateQuotas {
				if t.UTC().Sub(crl.lastSeen) >= rlq.staleAge {
					delete(rlq.rateQuotas, client)
				}
			}

			rlq.lock.Unlock()

		case <-rlq.closeCh:
			ticker.Stop()
			rlq.purgeEnabled = false
			return
		}
	}
}

// clientRateLimiter returns a reference to a ClientRateLimiter based on a
// provided client address (e.g. IP address). If the ClientRateLimiter does not
// exist in the RateLimitQuota's mapping, one will be created and set. The
// created limiter will have its requests-per-second set to RateLimitQuota.Rate.
// If the ClientRateLimiter already exists, the lastSeen timestamp will be
// updated.
func (rlq *RateLimitQuota) clientRateLimiter(addr string) *ClientRateLimiter {
	rlq.lock.Lock()
	defer rlq.lock.Unlock()

	crl, ok := rlq.rateQuotas[addr]
	if !ok {
		limiter := newClientRateLimiter(rlq.Rate, rlq.Burst)
		rlq.rateQuotas[addr] = limiter
		return limiter
	}

	crl.lastSeen = time.Now().UTC()
	return crl
}

// allow decides if the request is allowed by the quota. An error will be
// returned if the request client address is empty. If the path is exempt, the
// quota will not be evaluated. Otherwise, the client rate limiter is retrieved
// by address and the rate limit quota is checked against that limiter.
func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
	var resp Response

	// Skip rate limit checks for paths that are exempt from rate limiting.
	if rateLimitExemptPaths.HasPath(req.Path) {
		resp.Allowed = true
		return resp, nil
	}

	if req.ClientAddress == "" {
		return resp, fmt.Errorf("missing request client address in quota request")
	}

	resp.Allowed = rlq.clientRateLimiter(req.ClientAddress).limiter.Allow()
	if !resp.Allowed {
		rlq.metricSink.IncrCounterWithLabels([]string{"quota", "rate_limit", "violation"}, 1, []metrics.Label{{"name", rlq.Name}})
	}

	return resp, nil
}

// close stops the currently running client purge loop.
func (rlq *RateLimitQuota) close() error {
	close(rlq.closeCh)
	return nil
}

func (rlq *RateLimitQuota) handleRemount(toPath string) {
	rlq.MountPath = toPath
}
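
As a quick illustration of the burst fallback in newClientRateLimiter, a hypothetical in-package Go example (not part of the commit):

package quotas

import "fmt"

// ExampleClientRateLimiter is illustrative only: with burstSize < 0 the burst
// falls back to ceil(maxRequests), so a 16.7 RPS limiter admits up to 17
// immediate requests before throttling to the refill rate.
func ExampleClientRateLimiter() {
	crl := newClientRateLimiter(16.7, -1)
	fmt.Println(crl.limiter.Burst())
	// Output: 17
}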
@ -0,0 +1,173 @@
package quotas

import (
	"fmt"
	"sync"
	"testing"
	"time"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/metricsutil"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"go.uber.org/atomic"
)

func TestNewClientRateLimiter(t *testing.T) {
	testCases := []struct {
		maxRequests   float64
		burstSize     int
		expectedBurst int
	}{
		{1000, -1, 1000},
		{1000, 5000, 5000},
		{16.1, -1, 17},
		{16.7, -1, 17},
		{16.7, 100, 100},
	}

	for _, tc := range testCases {
		crl := newClientRateLimiter(tc.maxRequests, tc.burstSize)
		b := crl.limiter.Burst()
		if b != tc.expectedBurst {
			t.Fatalf("unexpected burst size; expected: %d, got: %d", tc.expectedBurst, b)
		}
	}
}

func TestNewRateLimitQuota(t *testing.T) {
	rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50)
	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
		t.Fatal(err)
	}

	if !rlq.purgeEnabled {
		t.Fatal("expected rate limit quota to start purge loop")
	}

	if rlq.purgeInterval != DefaultRateLimitPurgeInterval {
		t.Fatalf("unexpected purgeInterval; expected: %d, got: %d", DefaultRateLimitPurgeInterval, rlq.purgeInterval)
	}
	if rlq.staleAge != DefaultRateLimitStaleAge {
		t.Fatalf("unexpected staleAge; expected: %d, got: %d", DefaultRateLimitStaleAge, rlq.staleAge)
	}
}

func TestRateLimitQuota_Close(t *testing.T) {
	rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50)

	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
		t.Fatal(err)
	}

	if err := rlq.close(); err != nil {
		t.Fatalf("unexpected error when closing: %v", err)
	}

	time.Sleep(time.Second) // allow enough time for purgeClientsLoop to receive on closeCh

	if rlq.purgeEnabled {
		t.Fatal("expected client purging to be disabled after close")
	}
}

func TestRateLimitQuota_Allow(t *testing.T) {
	rlq := &RateLimitQuota{
		Name:          "test-rate-limiter",
		Type:          TypeRateLimit,
		NamespacePath: "qa",
		MountPath:     "/foo/bar",
		Rate:          16.7,
		Burst:         83,
		purgeEnabled:  true, // to allow manual setting of purgeInterval and staleAge
	}

	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
		t.Fatal(err)
	}

	// override values and manually start purgeClientsLoop for testing purposes
	rlq.purgeInterval = 10 * time.Second
	rlq.staleAge = 10 * time.Second
	go rlq.purgeClientsLoop()

	var wg sync.WaitGroup

	type clientResult struct {
		atomicNumAllow *atomic.Int32
		atomicNumFail  *atomic.Int32
	}

	reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) {
		defer wg.Done()

		resp, err := rlq.allow(&Request{ClientAddress: addr})
		if err != nil {
			return
		}

		if resp.Allowed {
			atomicNumAllow.Add(1)
		} else {
			atomicNumFail.Add(1)
		}
	}

	results := make(map[string]*clientResult)

	start := time.Now()
	end := start.Add(5 * time.Second)
	for time.Now().Before(end) {
		for i := 0; i < 5; i++ {
			wg.Add(1)

			addr := fmt.Sprintf("127.0.0.%d", i)
			cr, ok := results[addr]
			if !ok {
				results[addr] = &clientResult{atomicNumAllow: atomic.NewInt32(0), atomicNumFail: atomic.NewInt32(0)}
				cr = results[addr]
			}

			go reqFunc(addr, cr.atomicNumAllow, cr.atomicNumFail)

			time.Sleep(2 * time.Millisecond)
		}
	}

	wg.Wait()

	if got, expected := len(results), len(rlq.rateQuotas); got != expected {
		t.Fatalf("unexpected number of tracked client rate limit quotas; got: %d, expected: %d", got, expected)
	}

	elapsed := time.Since(start)

	// evaluate the ideal allowed count as (burst + (RPS * totalSeconds)); e.g.
	// with Burst=83, Rate=16.7, and roughly 5s elapsed, ideal is about
	// 83 + 16.7*5 = 166.5 allowed requests per client
	ideal := float64(rlq.Burst) + (rlq.Rate * float64(elapsed) / float64(time.Second))

	for addr, cr := range results {
		numAllow := cr.atomicNumAllow.Load()
		numFail := cr.atomicNumFail.Load()

		// ensure there were some failed requests for the client
		if numFail == 0 {
			t.Fatalf("expected some requests to fail; addr: %s, numSuccess: %d, numFail: %d, elapsed: %d", addr, numAllow, numFail, elapsed)
		}

		// ensure that we never allow more requests than the quota permits for the client
		if want := int32(ideal + 1); numAllow > want {
			t.Fatalf("too many successful requests; addr: %s, want: %d, numSuccess: %d, numFail: %d, elapsed: %d", addr, want, numAllow, numFail, elapsed)
		}
	}

	// allow enough time for the clients to be purged
	time.Sleep(rlq.purgeInterval * 2)

	for addr := range results {
		rlc, ok := rlq.rateQuotas[addr]
		if ok || rlc != nil {
			t.Fatalf("expected stale client to be purged: %s", addr)
		}
	}
}

@ -0,0 +1,67 @@
package quotas

import (
	"context"
	"testing"

	"github.com/go-test/deep"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/metricsutil"
	"github.com/hashicorp/vault/sdk/helper/logging"
)

func TestQuotas_Precedence(t *testing.T) {
	qm, err := NewManager(logging.NewVaultLogger(log.Trace), nil, metricsutil.BlackholeSink())
	if err != nil {
		t.Fatal(err)
	}

	setQuotaFunc := func(t *testing.T, name, nsPath, mountPath string) Quota {
		t.Helper()
		quota := NewRateLimitQuota(name, nsPath, mountPath, 10, 20)
		err := qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true)
		if err != nil {
			t.Fatal(err)
		}
		return quota
	}

	checkQuotaFunc := func(t *testing.T, nsPath, mountPath string, expected Quota) {
		t.Helper()
		quota, err := qm.queryQuota(nil, &Request{
			Type:          TypeRateLimit,
			NamespacePath: nsPath,
			MountPath:     mountPath,
		})
		if err != nil {
			t.Fatal(err)
		}
		if diff := deep.Equal(expected, quota); len(diff) > 0 {
			t.Fatal(diff)
		}
	}

	// No quota present. Expect nil.
	checkQuotaFunc(t, "", "", nil)

	// Define a global quota and expect that to be returned.
	rateLimitGlobalQuota := setQuotaFunc(t, "rateLimitGlobalQuota", "", "")
	checkQuotaFunc(t, "", "", rateLimitGlobalQuota)

	// Define a global mount specific quota and expect that to be returned.
	rateLimitGlobalMountQuota := setQuotaFunc(t, "rateLimitGlobalMountQuota", "", "testmount")
	checkQuotaFunc(t, "", "testmount", rateLimitGlobalMountQuota)

	// Define a namespace quota and expect that to be returned.
	rateLimitNSQuota := setQuotaFunc(t, "rateLimitNSQuota", "testns", "")
	checkQuotaFunc(t, "testns", "", rateLimitNSQuota)

	// Define a namespace mount specific quota and expect that to be returned.
	rateLimitNSMountQuota := setQuotaFunc(t, "rateLimitNSMountQuota", "testns", "testmount")
	checkQuotaFunc(t, "testns", "testmount", rateLimitNSMountQuota)

	// Now that many quota types are defined, verify that the most specific
	// matches are returned per namespace.
	checkQuotaFunc(t, "", "", rateLimitGlobalQuota)
	checkQuotaFunc(t, "testns", "", rateLimitNSQuota)
}
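
To make the precedence lookup concrete, a hypothetical in-package example (not part of the commit) that registers a single global rate limit quota and applies it to a request; the imports mirror the test file above:

package quotas

import (
	"context"
	"fmt"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/metricsutil"
	"github.com/hashicorp/vault/sdk/helper/logging"
)

// ExampleManager_ApplyQuota is illustrative only: with a single global rate
// limit quota defined, a request in the root namespace is matched via the
// namespace index and allowed while burst tokens remain.
func ExampleManager_ApplyQuota() {
	qm, _ := NewManager(logging.NewVaultLogger(log.Trace), nil, metricsutil.BlackholeSink())

	quota := NewRateLimitQuota("global-rl", "", "", 10, 20)
	if err := qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, false); err != nil {
		panic(err)
	}

	resp, _ := qm.ApplyQuota(&Request{
		Type:          TypeRateLimit,
		Path:          "kv/secret",
		NamespacePath: "root",
		ClientAddress: "127.0.0.1",
	})
	fmt.Println(resp.Allowed)
	// Output: true
}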

@ -0,0 +1,65 @@
// +build !enterprise

package quotas

import (
	"context"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/metricsutil"

	memdb "github.com/hashicorp/go-memdb"
)

func quotaTypes() []string {
	return []string{
		TypeRateLimit.String(),
	}
}

func (m *Manager) init(walkFunc leaseWalkFunc) {}

func (m *Manager) recomputeLeaseCounts(ctx context.Context, txn *memdb.Txn) error {
	return nil
}

func (m *Manager) setIsPerfStandby(quota Quota) {}

func (m *Manager) inLeasePathCache(path string) bool {
	return false
}

type entManager struct {
	isPerfStandby bool
}

func (*entManager) Reset() error {
	return nil
}

type LeaseCountQuota struct {
}

func (l LeaseCountQuota) allow(request *Request) (Response, error) {
	panic("implement me")
}

func (l LeaseCountQuota) quotaID() string {
	panic("implement me")
}

func (l LeaseCountQuota) QuotaName() string {
	panic("implement me")
}

func (l LeaseCountQuota) initialize(logger log.Logger, sink *metricsutil.ClusterMetricSink) error {
	panic("implement me")
}

func (l LeaseCountQuota) close() error {
	panic("implement me")
}

func (l LeaseCountQuota) handleRemount(s string) {
	panic("implement me")
}

@ -24,6 +24,7 @@ import (
	"github.com/hashicorp/vault/sdk/helper/strutil"
	"github.com/hashicorp/vault/sdk/helper/wrapping"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault/quotas"
	uberAtomic "go.uber.org/atomic"
)

@ -539,7 +540,6 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp
	}

	// Create an audit trail of the response

	if !isControlGroupRun(req) {
		switch req.Path {
		case "sys/replication/dr/status", "sys/replication/performance/status", "sys/replication/status":

@ -708,6 +708,36 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
		}
	}

	leaseGenerated := false
	quotaResp, quotaErr := c.applyLeaseCountQuota(&quotas.Request{
		Path:          req.Path,
		MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path),
		NamespacePath: ns.Path,
	})
	if quotaErr != nil {
		c.logger.Error("failed to apply quota", "path", req.Path, "error", quotaErr)
		retErr = multierror.Append(retErr, quotaErr)
		return nil, auth, retErr
	}

	if !quotaResp.Allowed {
		if c.logger.IsTrace() {
			c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path)
		}

		retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded))
		return nil, auth, retErr
	}

	defer func() {
		if quotaResp.Access != nil {
			quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
			if quotaAckErr != nil {
				retErr = multierror.Append(retErr, quotaAckErr)
			}
		}
	}()

	// Route the request
	resp, routeErr := c.doRouting(ctx, req)
	if resp != nil {

@ -827,6 +857,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
			retErr = multierror.Append(retErr, ErrInternalError)
			return nil, auth, retErr
		}
		leaseGenerated = true
		resp.Secret.LeaseID = leaseID

		// Get the actual time of the lease

@ -917,6 +948,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
				retErr = multierror.Append(retErr, ErrInternalError)
				return nil, auth, retErr
			}
			leaseGenerated = true
		}
	}

@ -1073,6 +1105,46 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re

	// If the response generated an authentication, then generate the token
	if resp != nil && resp.Auth != nil {
		ns, err := namespace.FromContext(ctx)
		if err != nil {
			c.logger.Error("failed to get namespace from context", "error", err)
			retErr = multierror.Append(retErr, ErrInternalError)
			return
		}

		leaseGenerated := false

		// The request successfully authenticated itself. Run the quota checks
		// before creating the lease.
		quotaResp, quotaErr := c.applyLeaseCountQuota(&quotas.Request{
			Path:          req.Path,
			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path),
			NamespacePath: ns.Path,
		})

		if quotaErr != nil {
			c.logger.Error("failed to apply quota", "path", req.Path, "error", quotaErr)
			retErr = multierror.Append(retErr, quotaErr)
			return
		}

		if !quotaResp.Allowed {
			if c.logger.IsTrace() {
				c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path)
			}

			retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded))
			return
		}

		defer func() {
			if quotaResp.Access != nil {
				quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
				if quotaAckErr != nil {
					retErr = multierror.Append(retErr, quotaAckErr)
				}
			}
		}()

		var entity *identity.Entity
		auth = resp.Auth

@ -1141,10 +1213,6 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
			resp.AddWarning(warning)
		}

		ns, err := namespace.FromContext(ctx)
		if err != nil {
			return nil, nil, err
		}
		_, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, ns, auth.EntityID)
		if err != nil {
			return nil, nil, ErrInternalError

@ -1181,6 +1249,9 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
		err = registerFunc(ctx, tokenTTL, req.Path, auth)
		switch {
		case err == nil:
			if auth.TokenType != logical.TokenTypeBatch {
				leaseGenerated = true
			}
		case err == ErrInternalError:
			return nil, auth, err
		default:
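
The deferred ack above implements a provisional-count pattern: the quota is charged before routing, and the charge is rolled back if no lease was actually generated. A distilled, hypothetical sketch of that idea (all names invented, not Vault's actual implementation):

package main

import "fmt"

// counter mimics a lease count quota: Allow() increments provisionally and
// Ack() rolls the increment back when no lease actually got created.
type counter struct{ inUse, max int }

func (c *counter) Allow() bool {
	if c.inUse >= c.max {
		return false
	}
	c.inUse++ // count provisionally, before the lease exists
	return true
}

func (c *counter) Ack(leaseGenerated bool) {
	if !leaseGenerated {
		c.inUse-- // request finished without creating a lease; undo the charge
	}
}

func main() {
	c := &counter{max: 1}
	leaseGenerated := false
	if c.Allow() {
		// The closure captures leaseGenerated, so flipping the flag later
		// (e.g. after routing creates a lease) is observed at return time.
		defer func() { c.Ack(leaseGenerated) }()
	}
	// ... routing would happen here; set leaseGenerated = true on success ...
	fmt.Println(c.inUse) // 1 while the request is in flight
}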

@ -422,6 +422,14 @@ func (r *Router) MatchingSystemView(ctx context.Context, path string) logical.Sy
	return raw.(*routeEntry).backend.System()
}

// MatchingMountByAPIPath returns the mount path for the given API path
func (r *Router) MatchingMountByAPIPath(ctx context.Context, path string) string {
	me, _, _ := r.matchingMountEntryByPath(ctx, path, true)
	if me == nil {
		return ""
	}
	return me.Path
}

// MatchingStoragePrefixByAPIPath returns the storage prefix for the given api path
func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) {
	ns, err := namespace.FromContext(ctx)

@ -28,6 +28,14 @@ var (
	// ErrPerfStandbyForward is returned when Vault is in a state such that a
	// perf standby cannot satisfy a request
	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node")

	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
	// count quota being exceeded.
	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")

	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
	// rate limit quota being exceeded.
	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
)

type HTTPCodedError interface {

@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
		}
	})
	if allErrors != nil {
		return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
		return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors)
	}
	return codedErr.Code, errors.New(codedErr.Msg)
}

@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
			statusCode = http.StatusBadRequest
		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()):
			statusCode = http.StatusBadGateway
		case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()):
			statusCode = http.StatusTooManyRequests
		case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()):
			statusCode = http.StatusTooManyRequests
		}
	}

@ -163,6 +163,17 @@ These metrics cover measurement of token, identity, and lease operations, and co
| `vault.token.revoke-tree` | Time taken to revoke a token tree | ms | summary |
| `vault.token.store` | Time taken to store an updated token entry without writing to the secondary index | ms | summary |

## Resource Quota Metrics

These metrics relate to rate limit and lease count quotas. Each metric comes with a label "name" identifying the specific quota. For example, a request rejected by a rate limit quota named `webapp` increments `quota.rate_limit.violation` with the label `name=webapp`.

| Metric                        | Description                                                        | Unit  | Type    |
| :---------------------------- | :----------------------------------------------------------------- | :---- | :------ |
| `quota.rate_limit.violation`  | Total number of rate limit quota violations                        | quota | counter |
| `quota.lease_count.violation` | Total number of lease count quota violations                       | quota | counter |
| `quota.lease_count.max`       | Total maximum amount of leases allowed by the lease count quota    | lease | gauge   |
| `quota.lease_count.counter`   | Total current amount of leases generated by the lease count quota  | lease | gauge   |

## Merkle Tree and Write Ahead Log Metrics

These metrics relate to internal operations on Merkle Trees and Write Ahead Logs (WAL)