Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rate limiting to RPCs sent within a server instance too #5927

Merged
merged 1 commit into from
Jun 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions agent/consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"

metrics "github.com/armon/go-metrics"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/consul/fsm"
Expand All @@ -34,6 +35,7 @@ import (
"github.com/hashicorp/raft"
raftboltdb "github.com/hashicorp/raft-boltdb"
"github.com/hashicorp/serf/serf"
"golang.org/x/time/rate"
)

// These are the protocol versions that Consul can _understand_. These are
Expand Down Expand Up @@ -206,6 +208,10 @@ type Server struct {
// Enterprise user-defined areas.
router *router.Router

// rpcLimiter is used to rate limit the total number of RPCs initiated
// from an agent.
rpcLimiter atomic.Value

// Listener is used to listen for incoming connections
Listener net.Listener
rpcServer *rpc.Server
Expand Down Expand Up @@ -360,6 +366,8 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl
return nil, err
}

s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))

configReplicatorConfig := ReplicatorConfig{
Name: "Config Entry",
ReplicateFn: s.replicateConfig,
Expand Down Expand Up @@ -1028,6 +1036,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
args: args,
reply: reply,
}

// Enforce the RPC limit.
//
// "client" metric path because the internal client API is calling to the
// internal server API. It's odd that the same request directed to a server is
// recorded differently. On the other hand this possibly masks the different
// between regular client requests that traverse the network and these which
// don't (unless forwarded). This still seems most sane.
metrics.IncrCounter([]string{"client", "rpc"}, 1)
if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
return structs.ErrRPCRateExceeded
}
if err := s.rpcServer.ServeRequest(codec); err != nil {
return err
}
Expand All @@ -1039,6 +1060,19 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer,
replyFn structs.SnapshotReplyFn) error {

// Enforce the RPC limit.
//
// "client" metric path because the internal client API is calling to the
// internal server API. It's odd that the same request directed to a server is
// recorded differently. On the other hand this possibly masks the different
// between regular client requests that traverse the network and these which
// don't (unless forwarded). This still seems most sane.
metrics.IncrCounter([]string{"client", "rpc"}, 1)
if !s.rpcLimiter.Load().(*rate.Limiter).Allow() {
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
return structs.ErrRPCRateExceeded
}

// Perform the operation.
var reply structs.SnapshotResponse
snap, err := s.dispatchSnapshotRequest(args, in, &reply)
Expand Down Expand Up @@ -1141,6 +1175,8 @@ func (s *Server) GetLANCoordinate() (lib.CoordinateSet, error) {
// ReloadConfig is used to have the Server do an online reload of
// relevant configuration information
func (s *Server) ReloadConfig(config *Config) error {
s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
hanshasselberg marked this conversation as resolved.
Show resolved Hide resolved

if s.IsLeader() {
// only bootstrap the config entries if we are the leader
// this will error if we lose leadership while bootstrapping here.
Expand Down
37 changes: 37 additions & 0 deletions agent/consul/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-uuid"
"golang.org/x/time/rate"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -988,6 +989,8 @@ func TestServer_Reload(t *testing.T) {

dir1, s := testServerWithConfig(t, func(c *Config) {
c.Build = "1.5.0"
c.RPCRate = 500
c.RPCMaxBurst = 5000
})
defer os.RemoveAll(dir1)
defer s.Shutdown()
Expand All @@ -998,6 +1001,14 @@ func TestServer_Reload(t *testing.T) {
global_entry_init,
}

limiter := s.rpcLimiter.Load().(*rate.Limiter)
require.Equal(t, rate.Limit(500), limiter.Limit())
require.Equal(t, 5000, limiter.Burst())

// Change rate limit
s.config.RPCRate = 1000
s.config.RPCMaxBurst = 10000

s.ReloadConfig(s.config)

_, entry, err := s.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal)
Expand All @@ -1008,4 +1019,30 @@ func TestServer_Reload(t *testing.T) {
require.Equal(t, global_entry_init.Kind, global.Kind)
require.Equal(t, global_entry_init.Name, global.Name)
require.Equal(t, global_entry_init.Config, global.Config)

// Check rate limiter got updated
limiter = s.rpcLimiter.Load().(*rate.Limiter)
require.Equal(t, rate.Limit(1000), limiter.Limit())
require.Equal(t, 10000, limiter.Burst())
}

func TestServer_RPC_RateLimit(t *testing.T) {
t.Parallel()
dir1, conf1 := testServerConfig(t)
conf1.RPCRate = 2
conf1.RPCMaxBurst = 2
s1, err := NewServer(conf1)
if err != nil {
t.Fatalf("err: %v", err)
}
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")

retry.Run(t, func(r *retry.R) {
var out struct{}
if err := s1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded {
r.Fatalf("err: %v", err)
}
})
}