diff --git a/pkg/autoscaler/multiscaler.go b/pkg/autoscaler/multiscaler.go index 385466ff19bd..0eed9037a6f0 100644 --- a/pkg/autoscaler/multiscaler.go +++ b/pkg/autoscaler/multiscaler.go @@ -76,6 +76,7 @@ type UniScalerFactory func(*Metric, *DynamicConfig) (UniScaler, error) type scalerRunner struct { scaler UniScaler stopCh chan struct{} + pokeCh chan struct{} // mux guards access to metric mux sync.RWMutex @@ -147,6 +148,17 @@ func (m *MultiScaler) Get(ctx context.Context, namespace, name string) (*Metric, return (&scaler.metric).DeepCopy(), nil } +// SetScale directly sets the scale for a given metric key. This does not perform any ticking +// or updating of other scaler components. +func (m *MultiScaler) SetScale(metricKey string, scale int32) bool { + scaler, exists := m.scalers[metricKey] + if !exists { + return false + } + scaler.updateLatestScale(scale) + return true +} + // Create instantiates the desired Metric. func (m *MultiScaler) Create(ctx context.Context, metric *Metric) (*Metric, error) { m.scalersMutex.Lock() @@ -229,6 +241,7 @@ func (m *MultiScaler) createScaler(ctx context.Context, metric *Metric) (*scaler scaler: scaler, stopCh: stopCh, metric: *metric, + pokeCh: make(chan struct{}), } runner.metric.Status.DesiredScale = -1 @@ -247,6 +260,8 @@ func (m *MultiScaler) createScaler(ctx context.Context, metric *Metric) (*scaler return case <-ticker.C: m.tickScaler(ctx, scaler, scaleChan) + case <-runner.pokeCh: + m.tickScaler(ctx, scaler, scaleChan) } } }() @@ -300,6 +315,10 @@ func (m *MultiScaler) RecordStat(key string, stat Stat) { if exists { logger := m.logger.With(zap.String(logkey.Key, key)) ctx := logging.WithLogger(context.TODO(), logger) + scaler.scaler.Record(ctx, stat) + if scaler.getLatestScale() == 0 && stat.AverageConcurrentRequests != 0 { + scaler.pokeCh <- struct{}{} + } } } diff --git a/pkg/autoscaler/multiscaler_test.go b/pkg/autoscaler/multiscaler_test.go index e238d369f772..eb4141d6b5d3 100644 --- a/pkg/autoscaler/multiscaler_test.go +++ b/pkg/autoscaler/multiscaler_test.go @@ -174,6 +174,64 @@ func TestMultiScalerScaleToZero(t *testing.T) { } } +func TestMultiScalerScaleFromZero(t *testing.T) { + ctx := context.TODO() + ms, stopCh, uniScaler := createMultiScaler(t, &autoscaler.Config{ + TickInterval: time.Second * 60, + EnableScaleToZero: true, + }) + defer close(stopCh) + + metric := newMetric() + metricKey := fmt.Sprintf("%s/%s", metric.Namespace, metric.Name) + + uniScaler.setScaleResult(1, true) + + done := make(chan struct{}) + defer close(done) + ms.Watch(func(key string) { + // When we return, let the main process know. + defer func() { + done <- struct{}{} + }() + if key != metricKey { + t.Errorf("Watch() = %v, wanted %v", key, metricKey) + } + m, err := ms.Get(ctx, metric.Namespace, metric.Name) + if err != nil { + t.Errorf("Get() = %v", err) + } + if got, want := m.Status.DesiredScale, int32(1); got != want { + t.Errorf("Get() = %v, wanted %v", got, want) + } + }) + + _, err := ms.Create(ctx, metric) + if err != nil { + t.Errorf("Create() = %v", err) + } + if ok := ms.SetScale(autoscaler.NewMetricKey(metric.Namespace, metric.Name), 0); !ok { + t.Error("Failed to set scale for metric to 0") + } + + now := time.Now() + testStat := autoscaler.Stat{ + Time: &now, + PodName: "test-pod", + AverageConcurrentRequests: 1, + RequestCount: 1, + } + ms.RecordStat(testKPAKey, testStat) + + // Verify that we see a "tick" on scale from zero + select { + case <-done: + // We got the signal! + case <-time.After(30 * time.Millisecond): + t.Fatalf("timed out waiting for Watch()") + } +} + func TestMultiScalerWithoutScaleToZero(t *testing.T) { ctx := context.TODO() ms, stopCh, uniScaler := createMultiScaler(t, &autoscaler.Config{