Add rate limiting to RPCs sent within a server instance too (#5927)

Paul Banks · 2019-06-13 04:26:27 -05:00 · committed by GitHub
parent c28ace2db1
commit e90fab0aec
2 changed files with 73 additions and 0 deletions

agent/consul/server.go

@@ -18,6 +18,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	metrics "github.com/armon/go-metrics"
 	ca "github.com/hashicorp/consul/agent/connect/ca"
 	"github.com/hashicorp/consul/agent/consul/autopilot"
 	"github.com/hashicorp/consul/agent/consul/fsm"
@@ -34,6 +35,7 @@ import (
 	"github.com/hashicorp/raft"
 	raftboltdb "github.com/hashicorp/raft-boltdb"
 	"github.com/hashicorp/serf/serf"
+	"golang.org/x/time/rate"
 )
 
 // These are the protocol versions that Consul can _understand_. These are
@@ -206,6 +208,10 @@ type Server struct {
 	// Enterprise user-defined areas.
 	router *router.Router
 
+	// rpcLimiter is used to rate limit the total number of RPCs initiated
+	// from an agent.
+	rpcLimiter atomic.Value
+
 	// Listener is used to listen for incoming connections
 	Listener net.Listener
 	rpcServer *rpc.Server
@@ -360,6 +366,8 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl
 		return nil, err
 	}
 
+	s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
+
 	configReplicatorConfig := ReplicatorConfig{
 		Name:        "Config Entry",
 		ReplicateFn: s.replicateConfig,
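
Note: the limiter lives in an atomic.Value rather than a plain struct field so that ReloadConfig (below) can swap in a replacement while RPC handlers read it concurrently without taking a lock. A minimal standalone sketch of that pattern, assuming illustrative names (server, handle) that are not part of this commit:

package main

import (
	"fmt"
	"sync/atomic"

	"golang.org/x/time/rate"
)

type server struct {
	rpcLimiter atomic.Value // always holds a *rate.Limiter
}

func (s *server) handle() error {
	// Lock-free read; safe even while another goroutine swaps the limiter.
	if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
		return fmt.Errorf("rate limit exceeded")
	}
	return nil
}

func main() {
	s := &server{}
	s.rpcLimiter.Store(rate.NewLimiter(rate.Limit(1), 1))
	fmt.Println(s.handle()) // <nil>: the single token is consumed
	fmt.Println(s.handle()) // rate limit exceeded: bucket is empty

	// A "reload" swaps in a fresh limiter; subsequent reads see the new one.
	s.rpcLimiter.Store(rate.NewLimiter(rate.Limit(100), 100))
	fmt.Println(s.handle()) // <nil>
}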
@@ -1028,6 +1036,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
 		args:   args,
 		reply:  reply,
 	}
+
+	// Enforce the RPC limit.
+	//
+	// "client" metric path because the internal client API is calling to the
+	// internal server API. It's odd that the same request directed to a server is
+	// recorded differently. On the other hand this possibly masks the difference
+	// between regular client requests that traverse the network and those which
+	// don't (unless forwarded). This still seems most sane.
+	metrics.IncrCounter([]string{"client", "rpc"}, 1)
+	if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
+		metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
+		return structs.ErrRPCRateExceeded
+	}
+
 	if err := s.rpcServer.ServeRequest(codec); err != nil {
 		return err
 	}
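
Note: golang.org/x/time/rate implements a token bucket, so in the check above config.RPCRate is the refill rate in tokens per second, config.RPCMaxBurst is the bucket capacity, and Allow() consumes a token without blocking, returning false when the bucket is empty. A self-contained sketch of those semantics, with arbitrary values:

package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Refill 2 tokens per second; the bucket holds at most 2 tokens.
	lim := rate.NewLimiter(rate.Limit(2), 2)

	// The bucket starts full, so the first two calls succeed and the
	// third is rejected, mirroring structs.ErrRPCRateExceeded above.
	for i := 0; i < 3; i++ {
		fmt.Println(lim.Allow()) // true, true, false
	}

	// After waiting, refilled tokens make Allow succeed again.
	time.Sleep(600 * time.Millisecond)
	fmt.Println(lim.Allow()) // true
}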
@@ -1039,6 +1060,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
 func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer,
 	replyFn structs.SnapshotReplyFn) error {
+
+	// Enforce the RPC limit.
+	//
+	// "client" metric path because the internal client API is calling to the
+	// internal server API. It's odd that the same request directed to a server is
+	// recorded differently. On the other hand this possibly masks the difference
+	// between regular client requests that traverse the network and those which
+	// don't (unless forwarded). This still seems most sane.
+	metrics.IncrCounter([]string{"client", "rpc"}, 1)
+	if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
+		metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
+		return structs.ErrRPCRateExceeded
+	}
+
 	// Perform the operation.
 	var reply structs.SnapshotResponse
 	snap, err := s.dispatchSnapshotRequest(args, in, &reply)
@@ -1141,6 +1175,8 @@ func (s *Server) GetLANCoordinate() (lib.CoordinateSet, error) {
 // ReloadConfig is used to have the Server do an online reload of
 // relevant configuration information
 func (s *Server) ReloadConfig(config *Config) error {
+	s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
+
 	if s.IsLeader() {
 		// only bootstrap the config entries if we are the leader
 		// this will error if we lose leadership while bootstrapping here.
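
Note: storing a brand-new limiter on reload, rather than mutating the existing one with SetLimit/SetBurst, also hands out a fresh, full bucket, so a reload briefly forgives any recently consumed burst. A small sketch contrasting the two approaches, with arbitrary values:

package main

import (
	"fmt"

	"golang.org/x/time/rate"
)

func main() {
	lim := rate.NewLimiter(rate.Limit(1), 1)
	fmt.Println(lim.Allow()) // true: consumes the only token
	fmt.Println(lim.Allow()) // false: bucket is empty

	// Replacing the limiter (as ReloadConfig does via Store) starts over
	// with a full bucket, even if the new settings are identical.
	lim = rate.NewLimiter(rate.Limit(1), 1)
	fmt.Println(lim.Allow()) // true again immediately

	// Mutating in place instead preserves the current token count.
	lim.SetLimit(rate.Limit(2))
	lim.SetBurst(2)
	fmt.Println(lim.Allow()) // false: no tokens have accumulated yet
}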

agent/consul/server_test.go

@@ -21,6 +21,7 @@ import (
 	"github.com/hashicorp/consul/tlsutil"
 	"github.com/hashicorp/consul/types"
 	"github.com/hashicorp/go-uuid"
+	"golang.org/x/time/rate"
 
 	"github.com/stretchr/testify/require"
 )
@@ -988,6 +989,8 @@ func TestServer_Reload(t *testing.T) {
 	dir1, s := testServerWithConfig(t, func(c *Config) {
 		c.Build = "1.5.0"
+		c.RPCRate = 500
+		c.RPCMaxBurst = 5000
 	})
 	defer os.RemoveAll(dir1)
 	defer s.Shutdown()
@@ -998,6 +1001,14 @@
 		global_entry_init,
 	}
 
+	limiter := s.rpcLimiter.Load().(*rate.Limiter)
+	require.Equal(t, rate.Limit(500), limiter.Limit())
+	require.Equal(t, 5000, limiter.Burst())
+
+	// Change rate limit
+	s.config.RPCRate = 1000
+	s.config.RPCMaxBurst = 10000
+
 	s.ReloadConfig(s.config)
 
 	_, entry, err := s.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal)
@@ -1008,4 +1019,30 @@
 	require.Equal(t, global_entry_init.Kind, global.Kind)
 	require.Equal(t, global_entry_init.Name, global.Name)
 	require.Equal(t, global_entry_init.Config, global.Config)
+
+	// Check rate limiter got updated
+	limiter = s.rpcLimiter.Load().(*rate.Limiter)
+	require.Equal(t, rate.Limit(1000), limiter.Limit())
+	require.Equal(t, 10000, limiter.Burst())
+}
+
+func TestServer_RPC_RateLimit(t *testing.T) {
+	t.Parallel()
+	dir1, conf1 := testServerConfig(t)
+	conf1.RPCRate = 2
+	conf1.RPCMaxBurst = 2
+	s1, err := NewServer(conf1)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	defer os.RemoveAll(dir1)
+	defer s1.Shutdown()
+	testrpc.WaitForLeader(t, s1.RPC, "dc1")
+
+	retry.Run(t, func(r *retry.R) {
+		var out struct{}
+		if err := s1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded {
+			r.Fatalf("err: %v", err)
+		}
+	})
 }
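
Note: a caller that sees structs.ErrRPCRateExceeded, as the test above provokes, can simply back off and retry, since the token bucket refills continuously. A hedged sketch of one such retry loop; the sentinel error stand-in and backoff values are illustrative, not part of this commit:

package main

import (
	"errors"
	"fmt"
	"time"
)

// Stand-in for structs.ErrRPCRateExceeded from the Consul source.
var errRPCRateExceeded = errors.New("RPC rate limit exceeded")

func callWithRetry(call func() error, attempts int) error {
	var err error
	for i := 0; i < attempts; i++ {
		if err = call(); !errors.Is(err, errRPCRateExceeded) {
			return err // success, or a failure unrelated to rate limiting
		}
		// Rate-limited: back off before retrying so tokens can refill.
		time.Sleep(time.Duration(i+1) * 100 * time.Millisecond)
	}
	return err
}

func main() {
	calls := 0
	err := callWithRetry(func() error {
		calls++
		if calls < 3 {
			return errRPCRateExceeded // first two attempts are throttled
		}
		return nil
	}, 5)
	fmt.Println(calls, err) // 3 <nil>
}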