Skip to content

Commit

Permalink
Merge pull request #577 from evan2645/ratelimit-node-api
Browse files Browse the repository at this point in the history
Add ratelimiter for node api
  • Loading branch information
evan2645 authored Sep 11, 2018
2 parents 96cccc2 + b41169d commit dd55fcc
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 17 deletions.
9 changes: 9 additions & 0 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions pkg/agent/manager/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ func (m *manager) processEntryRequests(ctx context.Context, entryRequests entryR
return nil
}

// Truncate the number of entry requests we are making if it exceeds the CSR
// burst limit. The rest of the requests will be made on the next pass
if len(entryRequests) > node.CSRLimit {
entryRequests.truncate(node.CSRLimit)
}

_, svids, err := m.fetchUpdates(ctx, entryRequests)
if err != nil {
return err
Expand Down Expand Up @@ -189,6 +195,17 @@ func (er entryRequests) add(e *entryRequest) {
er[entryID] = e
}

func (er entryRequests) truncate(limit int) {
counter := 1
for id := range er {
if counter > limit {
delete(er, id)
}

counter++
}
}

func parseBundles(bundles map[string]*node.Bundle) (map[string][]*x509.Certificate, error) {
out := make(map[string][]*x509.Certificate)
for _, bundle := range bundles {
Expand Down
23 changes: 21 additions & 2 deletions pkg/server/endpoints/node/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"github.com/spiffe/spire/proto/server/nodeattestor"
"github.com/spiffe/spire/proto/server/noderesolver"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

type HandlerConfig struct {
Expand All @@ -33,7 +35,8 @@ type HandlerConfig struct {
}

type Handler struct {
c HandlerConfig
c HandlerConfig
limiter Limiter

// test hooks
hooks struct {
Expand All @@ -43,7 +46,8 @@ type Handler struct {

func NewHandler(config HandlerConfig) *Handler {
h := &Handler{
c: config,
c: config,
limiter: NewLimiter(config.Log),
}
h.hooks.now = time.Now
return h
Expand All @@ -61,6 +65,11 @@ func (h *Handler) Attest(stream node.Node_AttestServer) (err error) {
return err
}

err = h.limiter.Limit(ctx, AttestMsg, 1)
if err != nil {
return status.Error(codes.ResourceExhausted, err.Error())
}

baseSpiffeIDFromCSR, err := getSpiffeIDFromCSR(request.Csr)
if err != nil {
h.c.Log.Error(err)
Expand Down Expand Up @@ -174,6 +183,11 @@ func (h *Handler) FetchX509SVID(server node.Node_FetchX509SVIDServer) (err error

ctx := server.Context()

err = h.limiter.Limit(ctx, CSRMsg, len(request.Csrs))
if err != nil {
return status.Error(codes.ResourceExhausted, err.Error())
}

peerCert, err := h.getCertFromCtx(ctx)
if err != nil {
h.c.Log.Error(err)
Expand Down Expand Up @@ -219,6 +233,11 @@ func (h *Handler) FetchX509SVID(server node.Node_FetchX509SVIDServer) (err error
}

func (h *Handler) FetchJWTSVID(ctx context.Context, req *node.FetchJWTSVIDRequest) (*node.FetchJWTSVIDResponse, error) {
err := h.limiter.Limit(ctx, JSRMsg, 1)
if err != nil {
return nil, status.Error(codes.ResourceExhausted, err.Error())
}

peerCert, err := h.getCertFromCtx(ctx)
if err != nil {
h.c.Log.Error(err)
Expand Down
86 changes: 71 additions & 15 deletions pkg/server/endpoints/node/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"encoding/pem"
"io"
"io/ioutil"
"net"
"net/url"
"path"
"sync"
"testing"
"time"

Expand All @@ -24,7 +26,6 @@ import (
"github.com/spiffe/spire/test/fakes/fakeserverca"
"github.com/spiffe/spire/test/fakes/fakeservercatalog"
"github.com/spiffe/spire/test/fakes/fakeupstreamca"
"github.com/spiffe/spire/test/mock/common/context"
"github.com/spiffe/spire/test/mock/proto/api/node"
"github.com/spiffe/spire/test/mock/proto/server/datastore"
"github.com/spiffe/spire/test/mock/proto/server/nodeattestor"
Expand All @@ -49,11 +50,11 @@ type HandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
handler *Handler
limiter *fakeLimiter
mockDataStore *mock_datastore.MockDataStore
mockServerCA *mock_ca.MockServerCA
mockNodeAttestor *mock_nodeattestor.MockNodeAttestor
mockNodeResolver *mock_noderesolver.MockNodeResolver
mockContext *mock_context.MockContext
server *mock_node.MockNode_FetchX509SVIDServer
now time.Time
}
Expand All @@ -64,11 +65,11 @@ func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
mockCtrl := gomock.NewController(t)
suite.ctrl = mockCtrl
log, _ := test.NewNullLogger()
suite.limiter = new(fakeLimiter)
suite.mockDataStore = mock_datastore.NewMockDataStore(mockCtrl)
suite.mockServerCA = mock_ca.NewMockServerCA(mockCtrl)
suite.mockNodeAttestor = mock_nodeattestor.NewMockNodeAttestor(mockCtrl)
suite.mockNodeResolver = mock_noderesolver.NewMockNodeResolver(mockCtrl)
suite.mockContext = mock_context.NewMockContext(mockCtrl)
suite.server = mock_node.NewMockNode_FetchX509SVIDServer(suite.ctrl)
suite.now = time.Now()

Expand All @@ -86,26 +87,29 @@ func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
suite.handler.hooks.now = func() time.Time {
return suite.now
}
suite.handler.limiter = suite.limiter
return suite
}

func TestAttest(t *testing.T) {
suite := SetupHandlerTest(t)
defer suite.ctrl.Finish()

ctx := peer.NewContext(context.Background(), getFakePeer())
data := getAttestTestData()
setAttestExpectations(suite, data)

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)

stream := mock_node.NewMockNode_AttestServer(suite.ctrl)
stream.EXPECT().Context().Return(context.Background())
stream.EXPECT().Recv().Return(data.request, nil)
stream.EXPECT().Context().Return(ctx).AnyTimes()
stream.EXPECT().Recv().Return(data.request, nil).AnyTimes()

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)
stream.EXPECT().Send(&node.AttestResponse{
SvidUpdate: expected,
})
}).AnyTimes()

setAttestExpectations(suite, data)
suite.NoError(suite.handler.Attest(stream))
suite.Equal(1, suite.limiter.callsFor(AttestMsg))
}

func TestAttestChallengeResponse(t *testing.T) {
Expand All @@ -121,8 +125,9 @@ func TestAttestChallengeResponse(t *testing.T) {

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)

ctx := peer.NewContext(context.Background(), getFakePeer())
stream := mock_node.NewMockNode_AttestServer(suite.ctrl)
stream.EXPECT().Context().Return(context.Background())
stream.EXPECT().Context().Return(ctx)
stream.EXPECT().Recv().Return(data.request, nil)
stream.EXPECT().Send(&node.AttestResponse{
Challenge: []byte("1+1"),
Expand Down Expand Up @@ -155,6 +160,10 @@ func TestFetchX509SVID(t *testing.T) {
t.Errorf("Error was not expected\n Got: %v\n Want: %v\n", err, nil)
}

limiterCalls := suite.limiter.callsFor(CSRMsg)
if len(data.request.Csrs) != limiterCalls {
t.Errorf("expected %v calls to limiter; got %v", len(data.request.Csrs), limiterCalls)
}
}

func TestFetchX509SVIDWithRotation(t *testing.T) {
Expand Down Expand Up @@ -553,11 +562,10 @@ func setFetchX509SVIDExpectations(
caCert, _, err := util.LoadCAFixture()
require.NoError(suite.T(), err)

suite.server.EXPECT().Context().Return(suite.mockContext)
ctx := peer.NewContext(context.Background(), getFakePeer())
suite.server.EXPECT().Context().Return(ctx)
suite.server.EXPECT().Recv().Return(data.request, nil)

suite.mockContext.EXPECT().Value(gomock.Any()).Return(getFakePeer())

// begin FetchRegistrationEntries()

suite.mockDataStore.EXPECT().
Expand Down Expand Up @@ -695,8 +703,9 @@ func getFakePeer() *peer.Peer {
PeerCertificates: []*x509.Certificate{parsedCert},
}

addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
fakePeer := &peer.Peer{
Addr: nil,
Addr: addr,
AuthInfo: credentials.TLSInfo{State: state},
}

Expand Down Expand Up @@ -747,8 +756,15 @@ func TestFetchJWTSVID(t *testing.T) {
TrustDomain: testTrustDomain,
})

limiter := new(fakeLimiter)
handler.limiter = limiter

// no peer certificate on context
resp, err := handler.FetchJWTSVID(context.Background(), &node.FetchJWTSVIDRequest{})
badPeer := getFakePeer()
badPeer.AuthInfo = nil
badCtx := peer.NewContext(context.Background(), badPeer)
resp, err := handler.FetchJWTSVID(badCtx, &node.FetchJWTSVIDRequest{})
require.Equal(t, 1, limiter.callsFor(JSRMsg))
require.EqualError(t, err, "client SVID is required for this request")
require.Nil(t, resp)

Expand Down Expand Up @@ -827,3 +843,43 @@ func TestFetchJWTSVID(t *testing.T) {
},
}, resp.Bundles)
}

type fakeLimiter struct {
callsForAttest int
callsForCSR int
callsForJSR int

mtx sync.Mutex
}

func (fl *fakeLimiter) Limit(_ context.Context, msgType, count int) error {
fl.mtx.Lock()
defer fl.mtx.Unlock()

switch msgType {
case AttestMsg:
fl.callsForAttest += count
case CSRMsg:
fl.callsForCSR += count
case JSRMsg:
fl.callsForJSR += count
}

return nil
}

func (fl *fakeLimiter) callsFor(msgType int) int {
fl.mtx.Lock()
defer fl.mtx.Unlock()

switch msgType {
case AttestMsg:
return fl.callsForAttest
case CSRMsg:
return fl.callsForCSR
case JSRMsg:
return fl.callsForJSR
}

return 0
}
Loading

0 comments on commit dd55fcc

Please sign in to comment.