Add rate limiting to RPCs sent within a server instance too (#5927)
Parent: c28ace2db1
Commit: e90fab0aec
agent/consul/server.go

@@ -18,6 +18,7 @@ import (
 	"sync/atomic"
 	"time"

+	metrics "github.com/armon/go-metrics"
 	ca "github.com/hashicorp/consul/agent/connect/ca"
 	"github.com/hashicorp/consul/agent/consul/autopilot"
 	"github.com/hashicorp/consul/agent/consul/fsm"
@@ -34,6 +35,7 @@ import (
 	"github.com/hashicorp/raft"
 	raftboltdb "github.com/hashicorp/raft-boltdb"
 	"github.com/hashicorp/serf/serf"
+	"golang.org/x/time/rate"
 )

 // These are the protocol versions that Consul can _understand_. These are
@@ -206,6 +208,10 @@ type Server struct {
 	// Enterprise user-defined areas.
 	router *router.Router

+	// rpcLimiter is used to rate limit the total number of RPCs initiated
+	// from an agent.
+	rpcLimiter atomic.Value
+
 	// Listener is used to listen for incoming connections
 	Listener net.Listener
 	rpcServer *rpc.Server
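The new rpcLimiter field keeps the *rate.Limiter behind an atomic.Value so the hot RPC path can read it with a single atomic load while ReloadConfig swaps in a replacement without any locking. A minimal sketch of that holder pattern outside Consul (limiterHolder, set, and get are illustrative names, not part of this change):

    package main

    import (
    	"fmt"
    	"sync/atomic"

    	"golang.org/x/time/rate"
    )

    // limiterHolder mimics the field above: every request calls get(), while a
    // config reload calls set() to publish a new limiter atomically.
    type limiterHolder struct {
    	v atomic.Value // always holds a *rate.Limiter
    }

    func (h *limiterHolder) set(l *rate.Limiter) { h.v.Store(l) }
    func (h *limiterHolder) get() *rate.Limiter  { return h.v.Load().(*rate.Limiter) }

    func main() {
    	h := &limiterHolder{}
    	h.set(rate.NewLimiter(rate.Limit(500), 5000)) // 500 req/s, burst of 5000

    	fmt.Println(h.get().Allow()) // true: plenty of tokens available

    	// A reload publishes a stricter limiter; readers pick it up on their
    	// next get() without any coordination.
    	h.set(rate.NewLimiter(rate.Limit(1), 1))
    	fmt.Println(h.get().Burst()) // 1
    }

Publishing a whole new limiter, rather than mutating one in place, is what makes the reload further down safe against concurrent readers.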
@@ -360,6 +366,8 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl
 		return nil, err
 	}

+	s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
+
 	configReplicatorConfig := ReplicatorConfig{
 		Name:        "Config Entry",
 		ReplicateFn: s.replicateConfig,
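Seeding rpcLimiter during construction means RPC and SnapshotRPC can always type-assert the stored value without a nil check. It also helps that golang.org/x/time/rate can express "no limit": a limiter built with rate.Inf never rejects, so the same code path covers both throttled and unthrottled configurations. A small sketch of both cases (values are illustrative):

    package main

    import (
    	"fmt"

    	"golang.org/x/time/rate"
    )

    func main() {
    	// A finite limiter rejects once its burst is exhausted.
    	finite := rate.NewLimiter(rate.Limit(1), 1)
    	fmt.Println(finite.Allow(), finite.Allow()) // true false

    	// rate.Inf means "no limit": Allow always succeeds, regardless of burst.
    	unlimited := rate.NewLimiter(rate.Inf, 0)
    	fmt.Println(unlimited.Allow(), unlimited.Allow()) // true true
    }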
@@ -1028,6 +1036,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
 		args:   args,
 		reply:  reply,
 	}
+
+	// Enforce the RPC limit.
+	//
+	// "client" metric path because the internal client API is calling to the
+	// internal server API. It's odd that the same request directed to a server is
+	// recorded differently. On the other hand this possibly masks the different
+	// between regular client requests that traverse the network and these which
+	// don't (unless forwarded). This still seems most sane.
+	metrics.IncrCounter([]string{"client", "rpc"}, 1)
+	if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
+		metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
+		return structs.ErrRPCRateExceeded
+	}
 	if err := s.rpcServer.ServeRequest(codec); err != nil {
 		return err
 	}
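The guard itself is a plain token-bucket check: RPCRate is the refill rate in requests per second, RPCMaxBurst is the bucket size, and Allow() consumes a token without blocking, so requests start failing with ErrRPCRateExceeded once the burst is spent faster than it refills. A standalone sketch of that behaviour, using the same values as the new TestServer_RPC_RateLimit further down:

    package main

    import (
    	"fmt"

    	"golang.org/x/time/rate"
    )

    func main() {
    	// 2 tokens/second refill with a bucket of 2, as in the new test below.
    	lim := rate.NewLimiter(rate.Limit(2), 2)

    	for i := 1; i <= 5; i++ {
    		// Allow never blocks: the first two calls drain the burst and the
    		// rest fail until enough time has passed for tokens to refill.
    		fmt.Printf("request %d allowed=%v\n", i, lim.Allow())
    	}
    }

Note the counters: every internal RPC bumps the "client", "rpc" metric before the check, and rejected ones additionally bump "client", "rpc", "exceeded".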
@@ -1039,6 +1060,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
 func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer,
 	replyFn structs.SnapshotReplyFn) error {

+	// Enforce the RPC limit.
+	//
+	// "client" metric path because the internal client API is calling to the
+	// internal server API. It's odd that the same request directed to a server is
+	// recorded differently. On the other hand this possibly masks the different
+	// between regular client requests that traverse the network and these which
+	// don't (unless forwarded). This still seems most sane.
+	metrics.IncrCounter([]string{"client", "rpc"}, 1)
+	if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
+		metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
+		return structs.ErrRPCRateExceeded
+	}
+
 	// Perform the operation.
 	var reply structs.SnapshotResponse
 	snap, err := s.dispatchSnapshotRequest(args, in, &reply)
@@ -1141,6 +1175,8 @@ func (s *Server) GetLANCoordinate() (lib.CoordinateSet, error) {
 // ReloadConfig is used to have the Server do an online reload of
 // relevant configuration information
 func (s *Server) ReloadConfig(config *Config) error {
+	s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
+
 	if s.IsLeader() {
 		// only bootstrap the config entries if we are the leader
 		// this will error if we lose leadership while bootstrapping here.
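ReloadConfig publishes a brand-new limiter rather than adjusting the existing one, so previously consumed tokens are forgotten and the bucket effectively starts full at the new burst size; concurrent requests see either the old limiter or the new one, never a half-updated state. A small sketch of that reset-on-store behaviour (illustrative only):

    package main

    import (
    	"fmt"
    	"sync/atomic"

    	"golang.org/x/time/rate"
    )

    func main() {
    	var v atomic.Value
    	v.Store(rate.NewLimiter(rate.Limit(1), 1))

    	old := v.Load().(*rate.Limiter)
    	fmt.Println(old.Allow()) // true: consumes the single burst token
    	fmt.Println(old.Allow()) // false: bucket is now empty

    	// A "reload" stores a fresh limiter; its bucket is full again even
    	// though the limit and burst are unchanged.
    	v.Store(rate.NewLimiter(rate.Limit(1), 1))
    	fmt.Println(v.Load().(*rate.Limiter).Allow()) // true
    }

The rate package also offers SetLimit and SetBurst for adjusting a limiter in place; swapping the whole value instead keeps the reload path identical to the construction path in NewServerLogger.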
agent/consul/server_test.go

@@ -21,6 +21,7 @@ import (
 	"github.com/hashicorp/consul/tlsutil"
 	"github.com/hashicorp/consul/types"
 	"github.com/hashicorp/go-uuid"
+	"golang.org/x/time/rate"

 	"github.com/stretchr/testify/require"
 )
@@ -988,6 +989,8 @@ func TestServer_Reload(t *testing.T) {

 	dir1, s := testServerWithConfig(t, func(c *Config) {
 		c.Build = "1.5.0"
+		c.RPCRate = 500
+		c.RPCMaxBurst = 5000
 	})
 	defer os.RemoveAll(dir1)
 	defer s.Shutdown()
@@ -998,6 +1001,14 @@ func TestServer_Reload(t *testing.T) {
 		global_entry_init,
 	}

+	limiter := s.rpcLimiter.Load().(*rate.Limiter)
+	require.Equal(t, rate.Limit(500), limiter.Limit())
+	require.Equal(t, 5000, limiter.Burst())
+
+	// Change rate limit
+	s.config.RPCRate = 1000
+	s.config.RPCMaxBurst = 10000
+
 	s.ReloadConfig(s.config)

 	_, entry, err := s.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal)
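The type wrapping in these assertions is deliberate: Limiter.Limit() returns a rate.Limit, which is a named float64 type, so the expected value must be written as rate.Limit(500), while Burst() returns a plain int. A tiny sketch of the same distinction outside the test:

    package main

    import (
    	"fmt"

    	"golang.org/x/time/rate"
    )

    func main() {
    	lim := rate.NewLimiter(rate.Limit(500), 5000)

    	var l rate.Limit = lim.Limit() // rate.Limit is a named float64 type
    	var b int = lim.Burst()        // Burst is a plain int

    	fmt.Println(l == 500, b == 5000) // true true
    }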
@@ -1008,4 +1019,30 @@ func TestServer_Reload(t *testing.T) {
 	require.Equal(t, global_entry_init.Kind, global.Kind)
 	require.Equal(t, global_entry_init.Name, global.Name)
 	require.Equal(t, global_entry_init.Config, global.Config)
+
+	// Check rate limiter got updated
+	limiter = s.rpcLimiter.Load().(*rate.Limiter)
+	require.Equal(t, rate.Limit(1000), limiter.Limit())
+	require.Equal(t, 10000, limiter.Burst())
+}
+
+func TestServer_RPC_RateLimit(t *testing.T) {
+	t.Parallel()
+	dir1, conf1 := testServerConfig(t)
+	conf1.RPCRate = 2
+	conf1.RPCMaxBurst = 2
+	s1, err := NewServer(conf1)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	defer os.RemoveAll(dir1)
+	defer s1.Shutdown()
+	testrpc.WaitForLeader(t, s1.RPC, "dc1")
+
+	retry.Run(t, func(r *retry.R) {
+		var out struct{}
+		if err := s1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded {
+			r.Fatalf("err: %v", err)
+		}
+	})
+}
 }
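The new TestServer_RPC_RateLimit relies on the burst of 2 being shared by every call that goes through s1.RPC, including the checks made while waiting for a leader, so the retry.Run loop only needs a few Status.Ping attempts before one comes back as structs.ErrRPCRateExceeded, which is exactly what the assertion requires.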