Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ratelimiter for node api #577

Merged
merged 3 commits into from
Sep 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions pkg/agent/manager/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ func (m *manager) processEntryRequests(ctx context.Context, entryRequests entryR
return nil
}

// Truncate the number of entry requests we are making if it exceeds the CSR
// burst limit. The rest of the requests will be made on the next pass
if len(entryRequests) > node.CSRLimit {
entryRequests.truncate(node.CSRLimit)
}

_, svids, err := m.fetchUpdates(ctx, entryRequests)
if err != nil {
return err
Expand Down Expand Up @@ -189,6 +195,17 @@ func (er entryRequests) add(e *entryRequest) {
er[entryID] = e
}

func (er entryRequests) truncate(limit int) {
counter := 1
for id := range er {
if counter > limit {
delete(er, id)
}

counter++
}
}

func parseBundles(bundles map[string]*node.Bundle) (map[string][]*x509.Certificate, error) {
out := make(map[string][]*x509.Certificate)
for _, bundle := range bundles {
Expand Down
23 changes: 21 additions & 2 deletions pkg/server/endpoints/node/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"github.com/spiffe/spire/proto/server/nodeattestor"
"github.com/spiffe/spire/proto/server/noderesolver"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

type HandlerConfig struct {
Expand All @@ -33,7 +35,8 @@ type HandlerConfig struct {
}

type Handler struct {
c HandlerConfig
c HandlerConfig
limiter Limiter

// test hooks
hooks struct {
Expand All @@ -43,7 +46,8 @@ type Handler struct {

func NewHandler(config HandlerConfig) *Handler {
h := &Handler{
c: config,
c: config,
limiter: NewLimiter(config.Log),
}
h.hooks.now = time.Now
return h
Expand All @@ -61,6 +65,11 @@ func (h *Handler) Attest(stream node.Node_AttestServer) (err error) {
return err
}

err = h.limiter.Limit(ctx, AttestMsg, 1)
if err != nil {
return status.Error(codes.ResourceExhausted, err.Error())
}

baseSpiffeIDFromCSR, err := getSpiffeIDFromCSR(request.Csr)
if err != nil {
h.c.Log.Error(err)
Expand Down Expand Up @@ -174,6 +183,11 @@ func (h *Handler) FetchX509SVID(server node.Node_FetchX509SVIDServer) (err error

ctx := server.Context()

err = h.limiter.Limit(ctx, CSRMsg, len(request.Csrs))
if err != nil {
return status.Error(codes.ResourceExhausted, err.Error())
}

peerCert, err := h.getCertFromCtx(ctx)
if err != nil {
h.c.Log.Error(err)
Expand Down Expand Up @@ -219,6 +233,11 @@ func (h *Handler) FetchX509SVID(server node.Node_FetchX509SVIDServer) (err error
}

func (h *Handler) FetchJWTSVID(ctx context.Context, req *node.FetchJWTSVIDRequest) (*node.FetchJWTSVIDResponse, error) {
err := h.limiter.Limit(ctx, JSRMsg, 1)
if err != nil {
return nil, status.Error(codes.ResourceExhausted, err.Error())
}

peerCert, err := h.getCertFromCtx(ctx)
if err != nil {
h.c.Log.Error(err)
Expand Down
86 changes: 71 additions & 15 deletions pkg/server/endpoints/node/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"encoding/pem"
"io"
"io/ioutil"
"net"
"net/url"
"path"
"sync"
"testing"
"time"

Expand All @@ -24,7 +26,6 @@ import (
"github.com/spiffe/spire/test/fakes/fakeserverca"
"github.com/spiffe/spire/test/fakes/fakeservercatalog"
"github.com/spiffe/spire/test/fakes/fakeupstreamca"
"github.com/spiffe/spire/test/mock/common/context"
"github.com/spiffe/spire/test/mock/proto/api/node"
"github.com/spiffe/spire/test/mock/proto/server/datastore"
"github.com/spiffe/spire/test/mock/proto/server/nodeattestor"
Expand All @@ -49,11 +50,11 @@ type HandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
handler *Handler
limiter *fakeLimiter
mockDataStore *mock_datastore.MockDataStore
mockServerCA *mock_ca.MockServerCA
mockNodeAttestor *mock_nodeattestor.MockNodeAttestor
mockNodeResolver *mock_noderesolver.MockNodeResolver
mockContext *mock_context.MockContext
server *mock_node.MockNode_FetchX509SVIDServer
now time.Time
}
Expand All @@ -64,11 +65,11 @@ func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
mockCtrl := gomock.NewController(t)
suite.ctrl = mockCtrl
log, _ := test.NewNullLogger()
suite.limiter = new(fakeLimiter)
suite.mockDataStore = mock_datastore.NewMockDataStore(mockCtrl)
suite.mockServerCA = mock_ca.NewMockServerCA(mockCtrl)
suite.mockNodeAttestor = mock_nodeattestor.NewMockNodeAttestor(mockCtrl)
suite.mockNodeResolver = mock_noderesolver.NewMockNodeResolver(mockCtrl)
suite.mockContext = mock_context.NewMockContext(mockCtrl)
suite.server = mock_node.NewMockNode_FetchX509SVIDServer(suite.ctrl)
suite.now = time.Now()

Expand All @@ -86,26 +87,29 @@ func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
suite.handler.hooks.now = func() time.Time {
return suite.now
}
suite.handler.limiter = suite.limiter
return suite
}

func TestAttest(t *testing.T) {
suite := SetupHandlerTest(t)
defer suite.ctrl.Finish()

ctx := peer.NewContext(context.Background(), getFakePeer())
data := getAttestTestData()
setAttestExpectations(suite, data)

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)

stream := mock_node.NewMockNode_AttestServer(suite.ctrl)
stream.EXPECT().Context().Return(context.Background())
stream.EXPECT().Recv().Return(data.request, nil)
stream.EXPECT().Context().Return(ctx).AnyTimes()
stream.EXPECT().Recv().Return(data.request, nil).AnyTimes()

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)
stream.EXPECT().Send(&node.AttestResponse{
SvidUpdate: expected,
})
}).AnyTimes()

setAttestExpectations(suite, data)
suite.NoError(suite.handler.Attest(stream))
suite.Equal(1, suite.limiter.callsFor(AttestMsg))
}

func TestAttestChallengeResponse(t *testing.T) {
Expand All @@ -121,8 +125,9 @@ func TestAttestChallengeResponse(t *testing.T) {

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)

ctx := peer.NewContext(context.Background(), getFakePeer())
stream := mock_node.NewMockNode_AttestServer(suite.ctrl)
stream.EXPECT().Context().Return(context.Background())
stream.EXPECT().Context().Return(ctx)
stream.EXPECT().Recv().Return(data.request, nil)
stream.EXPECT().Send(&node.AttestResponse{
Challenge: []byte("1+1"),
Expand Down Expand Up @@ -155,6 +160,10 @@ func TestFetchX509SVID(t *testing.T) {
t.Errorf("Error was not expected\n Got: %v\n Want: %v\n", err, nil)
}

limiterCalls := suite.limiter.callsFor(CSRMsg)
if len(data.request.Csrs) != limiterCalls {
t.Errorf("expected %v calls to limiter; got %v", len(data.request.Csrs), limiterCalls)
}
}

func TestFetchX509SVIDWithRotation(t *testing.T) {
Expand Down Expand Up @@ -553,11 +562,10 @@ func setFetchX509SVIDExpectations(
caCert, _, err := util.LoadCAFixture()
require.NoError(suite.T(), err)

suite.server.EXPECT().Context().Return(suite.mockContext)
ctx := peer.NewContext(context.Background(), getFakePeer())
suite.server.EXPECT().Context().Return(ctx)
suite.server.EXPECT().Recv().Return(data.request, nil)

suite.mockContext.EXPECT().Value(gomock.Any()).Return(getFakePeer())

// begin FetchRegistrationEntries()

suite.mockDataStore.EXPECT().
Expand Down Expand Up @@ -695,8 +703,9 @@ func getFakePeer() *peer.Peer {
PeerCertificates: []*x509.Certificate{parsedCert},
}

addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
fakePeer := &peer.Peer{
Addr: nil,
Addr: addr,
AuthInfo: credentials.TLSInfo{State: state},
}

Expand Down Expand Up @@ -747,8 +756,15 @@ func TestFetchJWTSVID(t *testing.T) {
TrustDomain: testTrustDomain,
})

limiter := new(fakeLimiter)
handler.limiter = limiter

// no peer certificate on context
resp, err := handler.FetchJWTSVID(context.Background(), &node.FetchJWTSVIDRequest{})
badPeer := getFakePeer()
badPeer.AuthInfo = nil
badCtx := peer.NewContext(context.Background(), badPeer)
resp, err := handler.FetchJWTSVID(badCtx, &node.FetchJWTSVIDRequest{})
require.Equal(t, 1, limiter.callsFor(JSRMsg))
require.EqualError(t, err, "client SVID is required for this request")
require.Nil(t, resp)

Expand Down Expand Up @@ -827,3 +843,43 @@ func TestFetchJWTSVID(t *testing.T) {
},
}, resp.Bundles)
}

type fakeLimiter struct {
callsForAttest int
callsForCSR int
callsForJSR int

mtx sync.Mutex
}

func (fl *fakeLimiter) Limit(_ context.Context, msgType, count int) error {
fl.mtx.Lock()
defer fl.mtx.Unlock()

switch msgType {
case AttestMsg:
fl.callsForAttest += count
case CSRMsg:
fl.callsForCSR += count
case JSRMsg:
fl.callsForJSR += count
}

return nil
}

func (fl *fakeLimiter) callsFor(msgType int) int {
fl.mtx.Lock()
defer fl.mtx.Unlock()

switch msgType {
case AttestMsg:
return fl.callsForAttest
case CSRMsg:
return fl.callsForCSR
case JSRMsg:
return fl.callsForJSR
}

return 0
}
Loading