diff --git a/functional_test.go b/functional_test.go index 6225fc3e..f604e26d 100644 --- a/functional_test.go +++ b/functional_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "time" guber "github.com/mailgun/gubernator/v2" "github.com/mailgun/gubernator/v2/cluster" @@ -34,6 +35,10 @@ import ( "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/resolver" json "google.golang.org/protobuf/encoding/protojson" ) @@ -859,6 +864,62 @@ func TestGlobalRateLimits(t *testing.T) { }) } +func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { + owner := cluster.PeerAt(2).GRPCAddress + peer := cluster.PeerAt(0).GRPCAddress + assert.NotEqual(t, owner, peer) + + dialOpts := []grpc.DialOption{ + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), + } + + address := fmt.Sprintf("static:///%s,%s", owner, peer) + conn, err := grpc.DialContext(context.Background(), address, dialOpts...) + require.NoError(t, err) + + client := guber.NewV1Client(conn) + + sendHit := func(status guber.Status, assertion func(resp *guber.RateLimitResp), i int) string { + ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_global", + UniqueKey: "account:12345", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + require.NoError(t, err, i) + gotResp := resp.Responses[0] + assert.Equal(t, "", gotResp.GetError(), i) + assert.Equal(t, status, gotResp.GetStatus(), i) + + if assertion != nil { + assertion(gotResp) + } + + return gotResp.GetMetadata()["owner"] + } + + // Send two hits that should be processed by the owner and the peer and deplete the limit + sendHit(guber.Status_UNDER_LIMIT, nil, 1) + sendHit(guber.Status_UNDER_LIMIT, nil, 2) + // sleep to ensure the async forward has occurred and state should be shared + time.Sleep(time.Second * 5) + + for i := 0; i < 10; i++ { + sendHit(guber.Status_OVER_LIMIT, nil, i+2) + } +} + func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { resp, err := http.Get(url) require.NoError(t, err) @@ -1273,3 +1334,46 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { } return nil } + +// staticBuilder implements the `resolver.Builder` interface. +type staticBuilder struct{} + +func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + + } + r, err := newStaticResolver(cc, resolverAddrs) + if err != nil { + return nil, err + } + return r, nil +} + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func newStaticResolver(cc resolver.ClientConn, addresses []resolver.Address) (resolver.Resolver, error) { + err := cc.UpdateState(resolver.State{Addresses: addresses}) + if err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {}