diff --git a/operator/controllers/ping_test.go b/operator/controllers/ping_test.go
index dd94d030..61567162 100644
--- a/operator/controllers/ping_test.go
+++ b/operator/controllers/ping_test.go
@@ -29,7 +29,8 @@ func TestPingInterceptors(t *testing.T) {
 	r.NoError(err)
 	defer srv.Close()
 	ctx := context.Background()
-	endpoints := k8s.FakeEndpointsForURL(url, ns, svcName, 2)
+	endpoints, err := k8s.FakeEndpointsForURL(url, ns, svcName, 2)
+	r.NoError(err)
 	cl := fake.NewClientBuilder().WithObjects(endpoints).Build()
 	r.NoError(pingInterceptors(
 		ctx,
diff --git a/pkg/k8s/fake_endpoints.go b/pkg/k8s/fake_endpoints.go
index 43e0b67a..a4188c66 100644
--- a/pkg/k8s/fake_endpoints.go
+++ b/pkg/k8s/fake_endpoints.go
@@ -2,6 +2,7 @@ package k8s
 
 import (
 	"net/url"
+	"strconv"
 
 	v1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -16,13 +17,36 @@ func FakeEndpointsForURL(
 	namespace,
 	name string,
 	num int,
-) *v1.Endpoints {
-	addrs := make([]v1.EndpointAddress, num)
+) (*v1.Endpoints, error) {
+	urls := make([]*url.URL, num)
 	for i := 0; i < num; i++ {
+		urls[i] = u
+	}
+	return FakeEndpointsForURLs(urls, namespace, name)
+}
+
+// FakeEndpointsForURLs creates and returns a new
+// *v1.Endpoints with a single v1.EndpointSubset that
+// contains an address and a port for each URL in urls.
+func FakeEndpointsForURLs(
+	urls []*url.URL,
+	namespace,
+	name string,
+) (*v1.Endpoints, error) {
+	addrs := make([]v1.EndpointAddress, len(urls))
+	ports := make([]v1.EndpointPort, len(urls))
+	for i, u := range urls {
 		addrs[i] = v1.EndpointAddress{
 			Hostname: u.Hostname(),
 			IP:       u.Hostname(),
 		}
+		portInt, err := strconv.Atoi(u.Port())
+		if err != nil {
+			return nil, err
+		}
+		ports[i] = v1.EndpointPort{
+			Port: int32(portInt),
+		}
 	}
 	return &v1.Endpoints{
 		ObjectMeta: metav1.ObjectMeta{
@@ -32,7 +56,8 @@ func FakeEndpointsForURL(
 		Subsets: []v1.EndpointSubset{
 			{
 				Addresses: addrs,
+				Ports:     ports,
 			},
 		},
-	}
+	}, nil
 }
diff --git a/pkg/queue/queue_counts.go b/pkg/queue/queue_counts.go
index f1cfb992..fc8ce5bc 100644
--- a/pkg/queue/queue_counts.go
+++ b/pkg/queue/queue_counts.go
@@ -25,6 +25,15 @@ func NewCounts() *Counts {
 	}
 }
 
+// Aggregate returns the total count across all hosts
+func (q *Counts) Aggregate() int {
+	agg := 0
+	for _, count := range q.Counts {
+		agg += count
+	}
+	return agg
+}
+
 // MarshalJSON implements json.Marshaler
 func (q *Counts) MarshalJSON() ([]byte, error) {
 	return json.Marshal(q.Counts)
diff --git a/pkg/queue/queue_counts_test.go b/pkg/queue/queue_counts_test.go
new file mode 100644
index 00000000..6e877b5b
--- /dev/null
+++ b/pkg/queue/queue_counts_test.go
@@ -0,0 +1,23 @@
+package queue
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestAggregate(t *testing.T) {
+	r := require.New(t)
+	counts := NewCounts()
+	counts.Counts = map[string]int{
+		"host1": 123,
+		"host2": 234,
+		"host3": 456,
+		"host4": 567,
+	}
+	expectedAgg := 0
+	for _, v := range counts.Counts {
+		expectedAgg += v
+	}
+	r.Equal(expectedAgg, counts.Aggregate())
+}
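Editorial note: the reworked FakeEndpointsForURL/FakeEndpointsForURLs helpers above now parse each URL's port and can fail, which is why every caller in this patch grows an error check. A small, hypothetical usage sketch follows; the example URLs, namespace, and service name are made up, and only the function signature comes from the patch.

package k8s_test

import (
	"fmt"
	"net/url"

	"github.com/kedacore/http-add-on/pkg/k8s"
)

// Sketch: build a fake *v1.Endpoints from two URLs and read back
// the parsed addresses and ports.
func ExampleFakeEndpointsForURLs() {
	u1, _ := url.Parse("http://1.2.3.4:8080")
	u2, _ := url.Parse("http://5.6.7.8:9090")
	endpoints, err := k8s.FakeEndpointsForURLs(
		[]*url.URL{u1, u2},
		"testns",
		"testsvc",
	)
	if err != nil {
		panic(err)
	}
	subset := endpoints.Subsets[0]
	fmt.Println(len(subset.Addresses), subset.Ports[0].Port, subset.Ports[1].Port)
	// Output: 2 8080 9090
}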
diff --git a/scaler/config.go b/scaler/config.go
index 5859a097..d770a0d1 100644
--- a/scaler/config.go
+++ b/scaler/config.go
@@ -29,6 +29,8 @@ type config struct {
 	// UpdateRoutingTableDur is the duration between manual
 	// updates to the routing table.
 	UpdateRoutingTableDur time.Duration `envconfig:"KEDA_HTTP_SCALER_ROUTING_TABLE_UPDATE_DUR" default:"100ms"`
+	// QueueTickDuration is the duration between queue requests
+	QueueTickDuration time.Duration `envconfig:"KEDA_HTTP_QUEUE_TICK_DURATION" default:"500ms"`
 	// This will be the 'Target Pending Requests' for the interceptor
 	TargetPendingRequestsInterceptor int `envconfig:"KEDA_HTTP_SCALER_TARGET_PENDING_REQUESTS_INTERCEPTOR" default:"100"`
 }
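Editorial note: the new QueueTickDuration field reuses the same envconfig tags as the rest of the struct. Assuming the scaler keeps populating config with kelseyhightower/envconfig (the Process call and the empty prefix below are an illustration, not taken from this patch), an operator overrides the tick interval like so:

package main

import (
	"fmt"
	"os"
	"time"

	"github.com/kelseyhightower/envconfig"
)

// illustrative struct mirroring the two duration fields above
type exampleConfig struct {
	UpdateRoutingTableDur time.Duration `envconfig:"KEDA_HTTP_SCALER_ROUTING_TABLE_UPDATE_DUR" default:"100ms"`
	QueueTickDuration     time.Duration `envconfig:"KEDA_HTTP_QUEUE_TICK_DURATION" default:"500ms"`
}

func main() {
	// override the default tick duration through the environment
	os.Setenv("KEDA_HTTP_QUEUE_TICK_DURATION", "250ms")
	var cfg exampleConfig
	if err := envconfig.Process("", &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.QueueTickDuration) // 250ms
}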
diff --git a/scaler/handlers_test.go b/scaler/handlers_test.go
index 926a9d35..aec25502 100644
--- a/scaler/handlers_test.go
+++ b/scaler/handlers_test.go
@@ -19,7 +19,8 @@ func TestIsActive(t *testing.T) {
 	ctx := context.Background()
 	lggr := logr.Discard()
 	table := routing.NewTable()
-	ticker, pinger := newFakeQueuePinger(ctx, lggr)
+	ticker, pinger, err := newFakeQueuePinger(ctx, lggr)
+	r.NoError(err)
 	defer ticker.Stop()
 	pinger.pingMut.Lock()
 	pinger.allCounts[host] = 0
@@ -141,6 +190,17 @@ func TestGetMetricSpec(t *testing.T) {
 		},
 	}
 
 	for i, c := range cases {
 		testName := fmt.Sprintf("test case #%d: %s", i, c.name)
 		// capture tc in scope so that we can run the below test
 		t.Run(testName, func(t *testing.T) {
 			t.Parallel()
 			lggr := logr.Discard()
 			table := testCase.newRoutingTableFn()
-			ticker, pinger := newFakeQueuePinger(ctx, lggr)
+			ticker, pinger, err := newFakeQueuePinger(ctx, lggr)
+			if err != nil {
+				t.Fatalf(
+					"error creating new fake queue pinger and related components: %s",
+					err,
+				)
+			}
 			defer ticker.Stop()
 			hdl := newImpl(
 				lggr,
@@ -206,15 +273,52 @@ func TestGetMetrics(t *testing.T) {
 			return nil, nil, err
 		}
 
 		// create a fake queue pinger. this is the simulated
 		// scaler that pings the above fake interceptor
-		ticker, pinger := newFakeQueuePinger(
+		ticker, pinger, err := newFakeQueuePinger(
 			ctx,
 			lggr,
 			func(opts *fakeQueuePingerOpts) { opts.endpoints = endpoints },
 			func(opts *fakeQueuePingerOpts) { opts.tickDur = queuePingerTickDur },
 			func(opts *fakeQueuePingerOpts) { opts.port = fakeSrvURL.Port() },
 		)
+		if err != nil {
+			return nil, nil, err
+		}
 
 		// sleep for a bit to ensure the pinger has time to do its first tick
 		time.Sleep(10 * queuePingerTickDur)
@@ -233,7 +337,10 @@ func TestGetMetrics(t *testing.T) {
 				lggr logr.Logger,
 			) (*routing.Table, *queuePinger, func(), error) {
 				table := routing.NewTable()
-				ticker, pinger := newFakeQueuePinger(ctx, lggr)
+				ticker, pinger, err := newFakeQueuePinger(ctx, lggr)
+				if err != nil {
+					return nil, nil, nil, err
+				}
 				return table, pinger, func() { ticker.Stop() }, nil
 			},
 			checkFn: func(t *testing.T, res *externalscaler.GetMetricsResponse, err error) {
@@ -260,7 +367,10 @@ func TestGetMetrics(t *testing.T) {
 			) (*routing.Table, *queuePinger, func(), error) {
 				table := routing.NewTable()
 				// create queue and ticker without the host in it
-				ticker, pinger := newFakeQueuePinger(ctx, lggr)
+				ticker, pinger, err := newFakeQueuePinger(ctx, lggr)
+				if err != nil {
+					return nil, nil, nil, err
+				}
 				return table, pinger, func() { ticker.Stop() }, nil
 			},
 			checkFn: func(t *testing.T, res *externalscaler.GetMetricsResponse, err error) {
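Editorial note: the commented-out merge-conflict leftovers (the `<<<<<<< HEAD` / `=======` / `>>>>>>> 30fb2046…` blocks) were dropped from the hunks above; they only duplicated assertions that the surviving table-driven cases still make. For reference, a checker in that style looks roughly like the sketch below; the function name is invented here, and it leans on the imports already present in handlers_test.go (testing, require, externalscaler).

// hypothetical checkFn for a happy-path GetMetrics case: exactly one
// metric value, reported under the "interceptor" metric name
func checkInterceptorMetric(t *testing.T, res *externalscaler.GetMetricsResponse, err error) {
	r := require.New(t)
	r.NoError(err)
	r.NotNil(res)
	r.Equal(1, len(res.MetricValues))
	r.Equal("interceptor", res.MetricValues[0].MetricName)
}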
diff --git a/scaler/main.go b/scaler/main.go
index 4a837ade..83fe89e3 100644
--- a/scaler/main.go
+++ b/scaler/main.go
@@ -49,19 +49,31 @@ func main() {
 		lggr.Error(err, "getting a Kubernetes client")
 		os.Exit(1)
 	}
-	pinger := newQueuePinger(
+	pinger, err := newQueuePinger(
 		context.Background(),
 		lggr,
 		k8s.EndpointsFuncForK8sClientset(k8sCl),
 		namespace,
 		svcName,
 		targetPortStr,
-		time.NewTicker(500*time.Millisecond),
 	)
+	if err != nil {
+		lggr.Error(err, "creating a queue pinger")
+		os.Exit(1)
+	}
 
 	table := routing.NewTable()
 	grp, ctx := errgroup.WithContext(ctx)
+
+	grp.Go(func() error {
+		defer done()
+		return pinger.start(
+			ctx,
+			time.NewTicker(cfg.QueueTickDuration),
+		)
+	})
+
 	grp.Go(func() error {
 		defer done()
 		return startGrpcServer(
@@ -167,7 +179,7 @@ func startHealthcheckServer(
 	mux.HandleFunc("/queue_ping", func(w http.ResponseWriter, r *http.Request) {
 		ctx := r.Context()
 		lggr := lggr.WithName("route.counts_ping")
-		if err := pinger.requestCounts(ctx); err != nil {
+		if err := pinger.fetchAndSaveCounts(ctx); err != nil {
 			lggr.Error(err, "requesting counts failed")
 			w.WriteHeader(500)
 			w.Write([]byte("error requesting counts from interceptors"))
diff --git a/scaler/main_test.go b/scaler/main_test.go
index 921b5e85..5621fec7 100644
--- a/scaler/main_test.go
+++ b/scaler/main_test.go
@@ -34,7 +34,8 @@ func TestHealthChecks(t *testing.T) {
 
 	errgrp, ctx := errgroup.WithContext(ctx)
 
-	ticker, pinger := newFakeQueuePinger(ctx, lggr)
+	ticker, pinger, err := newFakeQueuePinger(ctx, lggr)
+	r.NoError(err)
 	defer ticker.Stop()
 	srvFunc := func() error {
 		return startHealthcheckServer(
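Editorial note: main.go now runs the pinger loop as one more errgroup member, so a failing member (or a canceled context) tears the whole group down together. Below is a stripped-down, standalone illustration of that shutdown pattern, with placeholder work instead of the real pinger and gRPC server.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	grp, ctx := errgroup.WithContext(context.Background())

	// stand-in for pinger.start: a ticker loop that stops when ctx is done
	grp.Go(func() error {
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-ticker.C:
				// pretend to fetch queue counts here
			}
		}
	})

	// stand-in for another server goroutine that fails
	grp.Go(func() error {
		time.Sleep(50 * time.Millisecond)
		return errors.New("grpc server exited")
	})

	// Wait returns the first error; the failing goroutine cancels ctx,
	// which also stops the ticker loop above.
	fmt.Println(grp.Wait())
}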
diff --git a/scaler/queue_pinger.go b/scaler/queue_pinger.go
index d874ae7f..addc9b64 100644
--- a/scaler/queue_pinger.go
+++ b/scaler/queue_pinger.go
@@ -11,9 +11,28 @@ import (
 	"github.com/go-logr/logr"
 	"github.com/kedacore/http-add-on/pkg/k8s"
 	"github.com/kedacore/http-add-on/pkg/queue"
+	"github.com/pkg/errors"
 	"golang.org/x/sync/errgroup"
 )
 
+// queuePinger has functionality to ping all interceptors
+// behind a given `Service`, fetch their pending queue counts,
+// and aggregate all of those counts together.
+//
+// It's capable of doing that work in parallel when possible
+// as well.
+//
+// Sample usage:
+//
+//	pinger, err := newQueuePinger(ctx, lggr, getEndpointsFn, ns, svcName, adminPort)
+//	if err != nil {
+//		panic(err)
+//	}
+//	// make sure to start the background pinger loop.
+//	// you can shut this loop down by using a cancellable
+//	// context
+//	go pinger.start(ctx, ticker)
+//
 type queuePinger struct {
 	getEndpointsFn k8s.GetEndpointsFunc
 	ns             string
@@ -33,8 +52,7 @@ func newQueuePinger(
 	ns,
 	svcName,
 	adminPort string,
-	pingTicker *time.Ticker,
-) *queuePinger {
+) (*queuePinger, error) {
 	pingMut := new(sync.RWMutex)
 	pinger := &queuePinger{
 		getEndpointsFn: getEndpointsFn,
@@ -44,18 +62,41 @@ func newQueuePinger(
 		pingMut:        pingMut,
 		lggr:           lggr,
 		allCounts:      map[string]int{},
+		aggregateCount: 0,
 	}
+	return pinger, pinger.fetchAndSaveCounts(ctx)
+}
 
-	go func() {
-		defer pingTicker.Stop()
-		for range pingTicker.C {
-			if err := pinger.requestCounts(ctx); err != nil {
+// start starts the queuePinger
+func (q *queuePinger) start(
+	ctx context.Context,
+	ticker *time.Ticker,
+) error {
+	lggr := q.lggr.WithName("scaler.queuePinger.start")
+	defer ticker.Stop()
+	for range ticker.C {
+		select {
+		case <-ctx.Done():
+			lggr.Error(
+				ctx.Err(),
+				"context marked done. stopping queuePinger loop",
+			)
+			return errors.Wrap(
+				ctx.Err(),
+				"context marked done. stopping queuePinger loop",
+			)
+		default:
+			err := q.fetchAndSaveCounts(ctx)
+			if err != nil {
 				lggr.Error(err, "getting request counts")
+				return errors.Wrap(
+					err,
+					"error getting request counts",
+				)
 			}
 		}
-	}()
-
-	return pinger
+	}
+	return nil
 }
 
 func (q *queuePinger) counts() map[string]int {
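Editorial note: a slightly fuller version of the "Sample usage" comment above, written as a hedged sketch against the scaler package (package main). The stub GetEndpointsFunc, namespace, service name, and port are placeholders; the constructor's initial fetch succeeds here only because the stub returns an empty Endpoints object, which is the same trick the test fakes in this patch use.

package main

import (
	"context"
	"time"

	"github.com/go-logr/logr"
	v1 "k8s.io/api/core/v1"
)

func exampleQueuePinger() error {
	ctx, cancel := context.WithCancel(context.Background())
	// the deferred cancel stops the background loop started below
	defer cancel()

	// stub endpoints lookup, standing in for a real Kubernetes client
	getEndpointsFn := func(context.Context, string, string) (*v1.Endpoints, error) {
		return &v1.Endpoints{}, nil
	}

	pinger, err := newQueuePinger(
		ctx,
		logr.Discard(),
		getEndpointsFn,
		"testns",
		"testsvc",
		"8080",
	)
	if err != nil {
		return err
	}

	// run the background loop; it exits when ctx is canceled
	go pinger.start(ctx, time.NewTicker(500*time.Millisecond))
	return nil
}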
@@ -70,25 +111,74 @@ func (q *queuePinger) aggregate() int {
 	return q.aggregateCount
 }
 
-func (q *queuePinger) requestCounts(ctx context.Context) error {
-	lggr := q.lggr.WithName("queuePinger.requestCounts")
-
-	endpointURLs, err := k8s.EndpointsForService(
+// fetchAndSaveCounts calls fetchCounts, and then
+// saves them to internal state in q
+func (q *queuePinger) fetchAndSaveCounts(ctx context.Context) error {
+	q.pingMut.Lock()
+	defer q.pingMut.Unlock()
+	counts, agg, err := fetchCounts(
 		ctx,
+		q.lggr,
+		q.getEndpointsFn,
 		q.ns,
 		q.svcName,
 		q.adminPort,
-		q.getEndpointsFn,
 	)
 	if err != nil {
+		q.lggr.Error(err, "getting request counts")
 		return err
 	}
+	q.allCounts = counts
+	q.aggregateCount = agg
+	q.lastPingTime = time.Now()
+
+	return nil
+
+}
+
+// fetchCounts fetches all counts from every endpoint returned
+// by endpointsFn for the given service named svcName on the
+// port adminPort, in namespace ns.
+//
+// Requests to fetch endpoints are made concurrently and
+// aggregated when all requests return successfully.
+//
+// Upon any failure, a non-nil error is returned and the
+// other two return values are nil and 0, respectively.
+func fetchCounts(
+	ctx context.Context,
+	lggr logr.Logger,
+	endpointsFn k8s.GetEndpointsFunc,
+	ns,
+	svcName,
+	adminPort string,
+) (map[string]int, int, error) {
+	lggr = lggr.WithName("queuePinger.requestCounts")
+
+	endpointURLs, err := k8s.EndpointsForService(
+		ctx,
+		ns,
+		svcName,
+		adminPort,
+		endpointsFn,
+	)
+	if err != nil {
+		return nil, 0, err
+	}
 
 	countsCh := make(chan *queue.Counts)
-	defer close(countsCh)
-	fetchGrp, _ := errgroup.WithContext(ctx)
+	var wg sync.WaitGroup
+	fetchGrp, ctx := errgroup.WithContext(ctx)
 	for _, endpoint := range endpointURLs {
+		// capture the endpoint in a loop-local
+		// variable so that the goroutine can
+		// use it
 		u := endpoint
+		// have the errgroup goroutine send to
+		// a "private" goroutine, which we'll
+		// then forward on to countsCh
+		ch := make(chan *queue.Counts)
+		wg.Add(1)
 		fetchGrp.Go(func() error {
 			counts, err := queue.GetCounts(
 				ctx,
@@ -105,43 +195,44 @@ func (q *queuePinger) requestCounts(ctx context.Context) error {
 				)
 				return err
 			}
-			countsCh <- counts
+			ch <- counts
 			return nil
 		})
+		// forward the "private" goroutine
+		// on to countsCh separately
+		go func() {
+			defer wg.Done()
+			res := <-ch
+			countsCh <- res
+		}()
 	}
 
-	// consume the results of the counts channel in a goroutine.
-	// we'll must for all the fetcher goroutines to finish after we
-	// start up this goroutine so that all goroutines can make
-	// progress
+	// close countsCh after all goroutines are done sending
+	// to their "private" channels, so that we can range
+	// over countsCh normally below
	go func() {
-		agg := 0
-		totalCounts := make(map[string]int)
-		// range through the result of each endpoint
-		for count := range countsCh {
-			// each endpoint returns a map of counts, one count
-			// per host. add up the counts for each host
-			for host, val := range count.Counts {
-				agg += val
-				totalCounts[host] += val
-			}
-		}
-
-		q.pingMut.Lock()
-		defer q.pingMut.Unlock()
-		q.allCounts = totalCounts
-		q.aggregateCount = agg
-		q.lastPingTime = time.Now()
+		wg.Wait()
+		close(countsCh)
 	}()
 
-	// now that the counts channel is being consumed, all the
-	// fetch goroutines can make progress. wait for them
-	// to finish and check for errors.
 	if err := fetchGrp.Wait(); err != nil {
 		lggr.Error(err, "fetching all counts failed")
-		return err
+		return nil, 0, err
 	}
 
-	return nil
+	// consume the results of the counts channel
+	agg := 0
+	totalCounts := make(map[string]int)
+	// range through the result of each endpoint
+	for count := range countsCh {
+		// each endpoint returns a map of counts, one count
+		// per host. add up the counts for each host
+		for host, val := range count.Counts {
+			agg += val
+			totalCounts[host] += val
+		}
+	}
+
+	return totalCounts, agg, nil
 }
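Editorial note: the "private channel" comments above describe a fan-out/fan-in shape: each errgroup worker hands its result to its own channel, a forwarder goroutine copies that onto the shared channel, and a closer goroutine closes the shared channel once every forwarder is done, which is what lets fetchCounts call Wait first and then drain the channel with a plain range. Below is a stripped-down, self-contained version of the same pattern, using ints instead of *queue.Counts.

package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	inputs := []int{1, 2, 3}
	results := make(chan int)
	var wg sync.WaitGroup
	grp, _ := errgroup.WithContext(context.Background())

	for _, input := range inputs {
		n := input // loop-local copy, as in fetchCounts
		// each worker gets a "private" channel...
		ch := make(chan int)
		wg.Add(1)
		grp.Go(func() error {
			ch <- n * 10 // pretend this was an HTTP fetch
			return nil
		})
		// ...and a forwarder copies the private channel onto results
		go func() {
			defer wg.Done()
			results <- <-ch
		}()
	}

	// close results only after every forwarder has delivered its value
	go func() {
		wg.Wait()
		close(results)
	}()

	// Wait returns once all workers have handed off to their forwarders;
	// then results can be drained with a plain range.
	if err := grp.Wait(); err != nil {
		panic(err)
	}
	total := 0
	for v := range results {
		total += v
	}
	fmt.Println(total) // 60
}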
diff --git a/scaler/queue_pinger_fake.go b/scaler/queue_pinger_fake.go
index 53eca52f..5df308be 100644
--- a/scaler/queue_pinger_fake.go
+++ b/scaler/queue_pinger_fake.go
@@ -31,13 +31,15 @@ func startFakeQueueEndpointServer(
 ) (*httptest.Server, *url.URL, *v1.Endpoints, error) {
 	hdl := http.NewServeMux()
 	queue.AddCountsRoute(logr.Discard(), hdl, q)
-	srv, url, err := kedanet.StartTestServer(hdl)
+	srv, srvURL, err := kedanet.StartTestServer(hdl)
 	if err != nil {
 		return nil, nil, nil, err
 	}
-
-	endpoints := k8s.FakeEndpointsForURL(url, ns, svcName, numEndpoints)
-	return srv, url, endpoints, nil
+	endpoints, err := k8s.FakeEndpointsForURL(srvURL, ns, svcName, numEndpoints)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	return srv, srvURL, endpoints, nil
 }
 
 type fakeQueuePingerOpts struct {
@@ -56,7 +58,7 @@ func newFakeQueuePinger(
 	ctx context.Context,
 	lggr logr.Logger,
 	optsFuncs ...optsFunc,
-) (*time.Ticker, *queuePinger) {
+) (*time.Ticker, *queuePinger, error) {
 	opts := &fakeQueuePingerOpts{
 		endpoints: &v1.Endpoints{},
 		tickDur:   time.Second,
@@ -66,21 +68,19 @@ func newFakeQueuePinger(
 		optsFunc(opts)
 	}
 	ticker := time.NewTicker(opts.tickDur)
-	pinger := newQueuePinger(
+
+	pinger, err := newQueuePinger(
 		ctx,
 		lggr,
-		func(
-			ctx context.Context,
-			namespace,
-			serviceName string,
-		) (*v1.Endpoints, error) {
+		func(context.Context, string, string) (*v1.Endpoints, error) {
 			return opts.endpoints, nil
-		},
+		},
 		"testns",
 		"testsvc",
 		opts.port,
-		ticker,
 	)
-	return ticker, pinger
+	if err != nil {
+		return nil, nil, err
+	}
+	return ticker, pinger, nil
 }
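Editorial note: newFakeQueuePinger keeps its functional-options shape, so a test overrides only the fields it cares about. A hypothetical extra test in that style is sketched below; the test name is invented, and the imports match the existing test files.

func TestFakePingerOptions(t *testing.T) {
	r := require.New(t)
	ctx := context.Background()
	// override just the tick interval; endpoints and port keep their defaults
	ticker, pinger, err := newFakeQueuePinger(
		ctx,
		logr.Discard(),
		func(opts *fakeQueuePingerOpts) { opts.tickDur = 5 * time.Millisecond },
	)
	r.NoError(err)
	defer ticker.Stop()
	r.NotNil(pinger)
}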
diff --git a/scaler/queue_pinger_test.go b/scaler/queue_pinger_test.go
index e1e09dbf..42e156e3 100644
--- a/scaler/queue_pinger_test.go
+++ b/scaler/queue_pinger_test.go
@@ -2,24 +2,23 @@ package main
 
 import (
 	context "context"
-	"net/http"
 	"testing"
 	"time"
 
 	"github.com/go-logr/logr"
-	"github.com/kedacore/http-add-on/pkg/k8s"
-	kedanet "github.com/kedacore/http-add-on/pkg/net"
 	"github.com/kedacore/http-add-on/pkg/queue"
 	"github.com/stretchr/testify/require"
 	v1 "k8s.io/api/core/v1"
 )
 
-func TestRequestCounts(t *testing.T) {
+func TestCounts(t *testing.T) {
 	r := require.New(t)
 	ctx := context.Background()
 	const (
-		ns      = "testns"
-		svcName = "testsvc"
+		ns           = "testns"
+		svcName      = "testsvc"
+		tickDur      = 10 * time.Millisecond
+		numEndpoints = 3
 	)
 
 	// assemble an in-memory queue and start up a fake server that serves it.
@@ -36,18 +35,15 @@ func TestRequestCounts(t *testing.T) {
 		q.Resize(host, count)
 	}
 
-	hdl := http.NewServeMux()
-	queue.AddCountsRoute(logr.Discard(), hdl, q)
-	srv, url, err := kedanet.StartTestServer(hdl)
+	srv, srvURL, endpoints, err := startFakeQueueEndpointServer(
+		ns,
+		svcName,
+		q,
+		3,
+	)
 	r.NoError(err)
 	defer srv.Close()
-
-	endpoints := k8s.FakeEndpointsForURL(url, ns, svcName, 3)
-	// set the initial ticker to effectively never tick so that we
-	// can check the behavior of the pinger before the first
-	// tick
-	ticker := time.NewTicker(10000 * time.Hour)
-	pinger := newQueuePinger(
+	pinger, err := newQueuePinger(
 		ctx,
 		logr.Discard(),
 		func(context.Context, string, string) (*v1.Endpoints, error) {
@@ -55,41 +51,161 @@ func TestRequestCounts(t *testing.T) {
 		},
 		ns,
 		svcName,
-		url.Port(),
-		ticker,
+		srvURL.Port(),
 	)
-	// the pinger starts a background watch loop but won't request the counts
-	// before the first tick. since the first tick effectively won't
-	// happen (it was set to a very long duration above), there should be
-	// no counts right now
+	r.NoError(err)
+	// the pinger does an initial fetch, so ensure that
+	// the saved counts are correct
 	retCounts := pinger.counts()
-	r.Equal(0, len(retCounts))
+	r.Equal(len(counts), len(retCounts))
 
-	// reset the ticker to tick practically immediately. sleep for a little
-	// bit to ensure that the tick occurred and the counts were successfully
-	// computed, then check them.
-	ticker.Reset(1 * time.Nanosecond)
-	time.Sleep(50 * time.Millisecond)
+	// now update the queue, start the ticker, and ensure
+	// that counts are updated after the first tick
+	q.Resize("host1", 1)
+	q.Resize("host2", 2)
+	q.Resize("host3", 3)
+	q.Resize("host4", 4)
+	ticker := time.NewTicker(tickDur)
+	go func() {
+		pinger.start(ctx, ticker)
+	}()
+	// sleep to ensure we ticked and finished calling
+	// fetchAndSaveCounts
+	time.Sleep(tickDur * 2)
 
-	// now that the tick has happened, there should be as many
-	// key/value pairs in the returned counts map as addresses
+	// now ensure that all the counts in the pinger
+	// are the same as in the queue, which has been updated
 	retCounts = pinger.counts()
-	r.Equal(len(counts), len(retCounts))
-
-	// each interceptor returns the same counts, so for each host in
-	// the counts map, the integer count should be
-	// (val * # interceptors)
-	for retHost, retCount := range retCounts {
-		expectedCount, ok := counts[retHost]
-		r.True(ok, "unexpected host %s returned", retHost)
-		expectedCount *= len(endpoints.Subsets[0].Addresses)
-		r.Equal(
-			expectedCount,
-			retCount,
-			"count for host %s was not the expected %d",
-			retCount,
-			expectedCount,
+	expectedCounts, err := q.Current()
+	r.NoError(err)
+	r.Equal(len(expectedCounts.Counts), len(retCounts))
+	for host, count := range expectedCounts.Counts {
+		retCount, ok := retCounts[host]
+		r.True(
+			ok,
+			"returned count not found for host %s",
+			host,
 		)
+
+		// note that the returned value should be:
+		// (queue_count * num_endpoints)
+		r.Equal(count*3, retCount)
+	}
+}
+
+func TestFetchAndSaveCounts(t *testing.T) {
+	r := require.New(t)
+	ctx, done := context.WithCancel(context.Background())
+	defer done()
+	const (
+		ns           = "testns"
+		svcName      = "testsvc"
+		adminPort    = "8081"
+		numEndpoints = 3
+	)
+	counts := queue.NewCounts()
+	counts.Counts = map[string]int{
+		"host1": 123,
+		"host2": 234,
+		"host3": 345,
+	}
+	q := queue.NewMemory()
+	for host, count := range counts.Counts {
+		q.Resize(host, count)
+	}
+	srv, srvURL, endpoints, err := startFakeQueueEndpointServer(
+		ns, svcName, q, numEndpoints,
+	)
+	r.NoError(err)
+	defer srv.Close()
+	endpointsFn := func(
+		ctx context.Context,
+		ns,
+		svcName string,
+	) (*v1.Endpoints, error) {
+		return endpoints, nil
+	}
+
+	pinger, err := newQueuePinger(
+		ctx,
+		logr.Discard(),
+		endpointsFn,
+		ns,
+		svcName,
+		srvURL.Port(),
+	)
+	r.NoError(err)
+
+	r.NoError(pinger.fetchAndSaveCounts(ctx))
+
+	// since all endpoints serve the same counts,
+	// expected aggregate is individual count * # endpoints
+	expectedAgg := counts.Aggregate() * numEndpoints
+	r.Equal(expectedAgg, pinger.aggregateCount)
+	// again, since all endpoints serve the same counts,
+	// the hosts will be the same as the original counts,
+	// but the value is (individual count * # endpoints)
+	expectedCounts := counts.Counts
+	for host, val := range expectedCounts {
+		expectedCounts[host] = val * numEndpoints
+	}
+	r.Equal(expectedCounts, pinger.allCounts)
+}
+
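Editorial note, a worked check of the arithmetic the test above encodes: for the assumed counts, Counts.Aggregate() is 123 + 234 + 345 = 702, and because the fake server is registered behind three identical endpoints, the pinger should end up with an aggregate of 702 * 3 = 2106 and with each per-host value tripled.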
+func TestFetchCounts(t *testing.T) {
+	r := require.New(t)
+	ctx, done := context.WithCancel(context.Background())
+	defer done()
+	const (
+		ns           = "testns"
+		svcName      = "testsvc"
+		adminPort    = "8081"
+		numEndpoints = 3
+	)
+	counts := queue.NewCounts()
+	counts.Counts = map[string]int{
+		"host1": 123,
+		"host2": 234,
+		"host3": 345,
+	}
+	q := queue.NewMemory()
+	for host, count := range counts.Counts {
+		r.NoError(q.Resize(host, count))
+	}
+	srv, srvURL, endpoints, err := startFakeQueueEndpointServer(
+		ns, svcName, q, numEndpoints,
+	)
+	r.NoError(err)
+
+	defer srv.Close()
+	endpointsFn := func(
+		context.Context,
+		string,
+		string,
+	) (*v1.Endpoints, error) {
+		return endpoints, nil
+	}
+
+	cts, agg, err := fetchCounts(
+		ctx,
+		logr.Discard(),
+		endpointsFn,
+		ns,
+		svcName,
+		srvURL.Port(),
+	)
+	r.NoError(err)
+	// since all endpoints serve the same counts,
+	// expected aggregate is individual count * # endpoints
+	expectedAgg := counts.Aggregate() * numEndpoints
+	r.Equal(expectedAgg, agg)
+	// again, since all endpoints serve the same counts,
+	// the hosts will be the same as the original counts,
+	// but the value is (individual count * # endpoints)
+	expectedCounts := counts.Counts
+	for host, val := range expectedCounts {
+		expectedCounts[host] = val * numEndpoints
+	}
+	r.Equal(expectedCounts, cts)
 }