diff --git a/README.md b/README.md index f921f82..ce1901e 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,15 @@ See [Netflix concurrency-limits](https://github.com/Netflix/concurrency-limits) for the inspiration. +## Why Use This Project? + +**Throttle Proxy** is the best solution for ensuring the safety and stability of your distributed systems during load spikes. Here's why: + +- **Proven Algorithm**: Utilizes the Additive Increase/Multiplicative Decrease (AIMD) algorithm, inspired by TCP congestion control, to dynamically adjust request concurrency. +- **Real-Time Metrics**: Leverages Prometheus metrics to make real-time decisions, ensuring your system adapts quickly to changing loads. +- **Configurable and Flexible**: Allows you to set custom thresholds and monitor multiple signals, providing fine-grained control over your traffic management. +- **Prevents Failures**: Helps prevent cascading failures and maintains system stability under unpredictable load conditions. + ## Key Features - 📊 **Adaptive Traffic Management**: Automatically adjusts request concurrency based on real-time Prometheus metrics diff --git a/internal/util/map.go b/internal/util/map.go new file mode 100644 index 0000000..6fe320d --- /dev/null +++ b/internal/util/map.go @@ -0,0 +1,44 @@ +// Package util holds custom structs and functions to handle common operations +package util + +import "sync" + +func MapKeys[K comparable, V any](m map[K]V) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// SyncMap is a typed sync.Map implementation +type SyncMap[K comparable, V comparable] struct { + mu sync.RWMutex + items map[K]V +} + +// NewSyncMap creates a new typed concurrent map +func NewSyncMap[K comparable, V comparable]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + items: map[K]V{}, + } +} + +// Store sets the value for a key +func (m *SyncMap[K, V]) Store(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.items[key] = value +} + +// Range calls f sequentially for each key and value in the map. +// If f returns false, range stops the iteration. +func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for k, v := range m.items { + if !f(k, v) { + break + } + } +} diff --git a/proxymw/backpressure.go b/proxymw/backpressure.go index 9df8ed2..a1f07f8 100644 --- a/proxymw/backpressure.go +++ b/proxymw/backpressure.go @@ -2,7 +2,6 @@ package proxymw import ( "context" - "encoding/json" "errors" "fmt" "log" @@ -15,6 +14,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/common/model" + + "github.com/kevindweb/throttle-proxy/internal/util" ) const ( @@ -22,7 +23,6 @@ const ( BackpressureUpdateCadence = 30 * time.Second MonitorQueryTimeout = 15 * time.Second DefaultThrottleCurve = 4.0 - InstantQueryEndpoint = "/api/v1/query" ) var ( @@ -217,7 +217,7 @@ type Backpressure struct { monitorClient *http.Client monitorURL string queries []BackpressureQuery - throttleFlags sync.Map + throttleFlags *util.SyncMap[BackpressureQuery, float64] allowance float64 client ProxyClient @@ -246,6 +246,7 @@ func NewBackpressure( warnGauge: bpQueryWarnGauge, emergencyGauge: bpQueryEmergencyGauge, queryValGauge: bpQueryValGauge, + throttleFlags: util.NewSyncMap[BackpressureQuery, float64](), monitorClient: &http.Client{ Timeout: MonitorQueryTimeout, @@ -287,62 +288,35 @@ func (bp *Backpressure) Next(rr Request) error { // preventing the other signals from actioning the congestion window. func (bp *Backpressure) metricsLoop(ctx context.Context) { for _, q := range bp.queries { - go func(query BackpressureQuery) { - bp.metricLoop(ctx, query) - }(q) - } -} - -// metricLoop pulls one PromQL metric on a loop to update whether requests should be throttled. -// we only drop the global throttle when all metrics have dropped their own throttle flag -func (bp *Backpressure) metricLoop(ctx context.Context, q BackpressureQuery) { - ticker := time.NewTicker(BackpressureUpdateCadence) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - curr, err := bp.metricFired(ctx, q.Query) - if err != nil { - bp.queryErrCount.WithLabelValues(q.Name).Inc() - log.Printf("querying metric '%s' returned error: %v", q.Query, err) - continue + go func(q BackpressureQuery) { + ticker := time.NewTicker(BackpressureUpdateCadence) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + curr, err := ValueFromPromQL(ctx, bp.monitorClient, bp.monitorURL, q.Query) + if err != nil { + bp.queryErrCount.WithLabelValues(q.Name).Inc() + log.Printf("querying metric '%s' returned error: %v", q.Query, err) + continue + } + + bp.queryValGauge.WithLabelValues(q.Name).Set(curr) + bp.updateThrottle(q, curr) + } } - - bp.queryValGauge.WithLabelValues(q.Name).Set(curr) - bp.updateThrottle(q, curr) - } + }(q) } } func (bp *Backpressure) updateThrottle(q BackpressureQuery, curr float64) { bp.throttleFlags.Store(q, q.throttlePercent(curr)) - throttlePercent := 0.0 - var err error - bp.throttleFlags.Range(func(key, value interface{}) bool { - query, ok := key.(BackpressureQuery) - if !ok { - log.Printf( - "error updating query '%s' throttle to %f: %v, expected query got %T", - q.Query, curr, err, query, - ) - return true - } - - val, ok := value.(float64) - if !ok { - bp.queryErrCount.WithLabelValues(query.Name).Inc() - log.Printf( - "error updating query '%s' throttle to %f: %v, expected float got %T", - q.Query, curr, err, val, - ) - return true - } - - throttlePercent = max(throttlePercent, val) + bp.throttleFlags.Range(func(_ BackpressureQuery, value float64) bool { + throttlePercent = max(throttlePercent, value) return true }) @@ -353,50 +327,6 @@ func (bp *Backpressure) updateThrottle(q BackpressureQuery, curr float64) { bp.mu.Unlock() } -// queryMetric checks if the PromQL expression returns a non-empty response (backpressure is firing) -func (bp *Backpressure) metricFired(ctx context.Context, query string) (float64, error) { - u, err := url.Parse(bp.monitorURL + InstantQueryEndpoint) - if err != nil { - return 0, fmt.Errorf("parse monitor URL: %w", err) - } - - q := u.Query() - q.Set("query", query) - u.RawQuery = q.Encode() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), http.NoBody) - if err != nil { - return 0, fmt.Errorf("create request: %w", err) - } - - resp, err := bp.monitorClient.Do(req) - if err != nil { - return 0, fmt.Errorf("execute request: %w", err) - } - - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return 0, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - var prometheusResp PrometheusResponse - if err := json.NewDecoder(resp.Body).Decode(&prometheusResp); err != nil { - return 0, fmt.Errorf("decode response: %w", err) - } - - results := prometheusResp.Data.Result - if len(results) != 1 { - return 0, fmt.Errorf("backpressure query must return exactly one value: %s", query) - } - - res := float64(results[0].Value) - if res < 0 { - return 0, fmt.Errorf("backpressure query (%s) must have non-negative value: %f", query, res) - } - - return res, nil -} - // check ensures the number of concurrent active requests stays within the allowed window. // If the active count exceeds the current watermark, the request is denied. func (bp *Backpressure) check() error { diff --git a/proxymw/backpressure_test.go b/proxymw/backpressure_test.go index 504a90f..5aa1536 100644 --- a/proxymw/backpressure_test.go +++ b/proxymw/backpressure_test.go @@ -1,17 +1,12 @@ package proxymw import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net/http" - "sync" "testing" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" + + "github.com/kevindweb/throttle-proxy/internal/util" ) func TestBackpressureRelease(t *testing.T) { @@ -94,135 +89,6 @@ func TestBackpressureRelease(t *testing.T) { } } -func TestMetricFired(t *testing.T) { - u := "http://localhost:9090" - for _, tt := range []struct { - name string - err error - val float64 - query string - bp *Backpressure - }{ - { - name: "error response", - err: errors.New("backpressure query must return exactly one value: sum(throughput)"), - query: "sum(throughput)", - bp: &Backpressure{ - monitorClient: &http.Client{ - Transport: &Mocker{ - RoundTripFunc: func(r *http.Request) (*http.Response, error) { - return &http.Response{ - Body: io.NopCloser(bytes.NewBufferString( - `{ - "status": "success", - "data": { - "resultType": "vector", - "result": [ - { - "metric": {}, - "value": [1731988543.752, "90"] - }, - { - "metric": {}, - "value": [1731988543.752, "95"] - } - ] - } - }`)), - StatusCode: http.StatusOK, - }, nil - }, - }, - }, - }, - }, - { - name: "negative float error", - err: errors.New( - "backpressure query (sum(throughput)) must have non-negative value: -90.000000", - ), - query: "sum(throughput)", - bp: &Backpressure{ - monitorClient: &http.Client{ - Transport: &Mocker{ - RoundTripFunc: func(r *http.Request) (*http.Response, error) { - return &http.Response{ - Body: io.NopCloser(bytes.NewBufferString( - `{ - "status": "success", - "data": { - "resultType": "vector", - "result": [ - { - "metric": {}, - "value": [1731988543.752, "-90"] - } - ] - } - }`)), - StatusCode: http.StatusOK, - }, nil - }, - }, - }, - }, - }, - { - name: "bad status code throws error", - err: fmt.Errorf("unexpected status code: %d", http.StatusBadGateway), - bp: &Backpressure{ - monitorURL: u, - monitorClient: &http.Client{ - Transport: &Mocker{ - RoundTripFunc: func(_ *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusBadGateway, - }, nil - }, - }, - }, - }, - }, - { - name: "valid request and response", - query: "sum(throughput)", - val: 90, - bp: &Backpressure{ - monitorURL: u, - monitorClient: &http.Client{ - Transport: &Mocker{ - RoundTripFunc: func(r *http.Request) (*http.Response, error) { - require.Equal(t, u+InstantQueryEndpoint+"?query=sum%28throughput%29", r.URL.String()) - return &http.Response{ - Body: io.NopCloser(bytes.NewBufferString( - `{ - "status": "success", - "data": { - "resultType": "vector", - "result": [ - { - "metric": {}, - "value": [1731988543.752, "90"] - } - ] - } - }`)), - StatusCode: http.StatusOK, - }, nil - }, - }, - }, - }, - }, - } { - t.Run(tt.name, func(t *testing.T) { - val, err := tt.bp.metricFired(context.Background(), tt.query) - require.Equal(t, tt.err, err) - require.Equal(t, tt.val, val) - }) - } -} - func TestUpdateThrottle(t *testing.T) { testGauge := prometheus.NewGauge( prometheus.GaugeOpts{Name: "fake_gauge_sensitive_bp_query"}, @@ -242,7 +108,7 @@ func TestUpdateThrottle(t *testing.T) { watermark: 80, max: 100, allowance: 0.2, - throttleFlags: sync.Map{}, + throttleFlags: util.NewSyncMap[BackpressureQuery, float64](), watermarkGauge: testGauge, allowanceGauge: testGauge, }, @@ -270,12 +136,15 @@ func TestUpdateThrottle(t *testing.T) { watermark: 80, max: 100, allowance: 0.2, - throttleFlags: sync.Map{}, + throttleFlags: util.NewSyncMap[BackpressureQuery, float64](), watermarkGauge: testGauge, allowanceGauge: testGauge, }, setup: func(b *Backpressure) { - b.throttleFlags.Store("previous", 0.8) + previous := BackpressureQuery{ + Query: "previous", + } + b.throttleFlags.Store(previous, 0.8) }, query: BackpressureQuery{ Query: `sum(rate(http_requests))`, @@ -296,7 +165,8 @@ func TestUpdateThrottle(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { tt.bp.updateThrottle(tt.query, tt.update) - tt.bp.throttleFlags = sync.Map{} + tt.bp.throttleFlags = util.NewSyncMap[BackpressureQuery, float64]() + tt.expect.throttleFlags = util.NewSyncMap[BackpressureQuery, float64]() require.Equal(t, tt.expect, tt.bp) }) } diff --git a/proxymw/mock.go b/proxymw/mocker.go similarity index 100% rename from proxymw/mock.go rename to proxymw/mocker.go diff --git a/proxymw/prometheus.go b/proxymw/prometheus.go new file mode 100644 index 0000000..ba7b6d2 --- /dev/null +++ b/proxymw/prometheus.go @@ -0,0 +1,60 @@ +package proxymw + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const ( + InstantQueryEndpoint = "/api/v1/query" +) + +// ValueFromPromQL queries the prometheus instant API for the prometheus query. +// Throws an error if the response is not a single value. +func ValueFromPromQL( + ctx context.Context, client *http.Client, endpoint, query string, +) (float64, error) { + u, err := url.Parse(endpoint + InstantQueryEndpoint) + if err != nil { + return 0, fmt.Errorf("parse monitor URL: %w", err) + } + + q := u.Query() + q.Set("query", query) + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), http.NoBody) + if err != nil { + return 0, fmt.Errorf("create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return 0, fmt.Errorf("execute request: %w", err) + } + + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var prometheusResp PrometheusResponse + if err := json.NewDecoder(resp.Body).Decode(&prometheusResp); err != nil { + return 0, fmt.Errorf("decode response: %w", err) + } + + results := prometheusResp.Data.Result + if len(results) != 1 { + return 0, fmt.Errorf("backpressure query must return exactly one value: %s", query) + } + + res := float64(results[0].Value) + if res < 0 { + return 0, fmt.Errorf("backpressure query (%s) must have non-negative value: %f", query, res) + } + + return res, nil +} diff --git a/proxymw/prometheus_test.go b/proxymw/prometheus_test.go new file mode 100644 index 0000000..ff47642 --- /dev/null +++ b/proxymw/prometheus_test.go @@ -0,0 +1,139 @@ +package proxymw_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/kevindweb/throttle-proxy/proxymw" +) + +func TestMetricFired(t *testing.T) { + u := "http://localhost:9090" + for _, tt := range []struct { + name string + err error + val float64 + query string + endpoint string + client *http.Client + }{ + { + name: "error response", + err: errors.New("backpressure query must return exactly one value: sum(throughput)"), + query: "sum(throughput)", + client: &http.Client{ + Transport: &proxymw.Mocker{ + RoundTripFunc: func(r *http.Request) (*http.Response, error) { + return &http.Response{ + Body: io.NopCloser(bytes.NewBufferString( + `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + { + "metric": {}, + "value": [1731988543.752, "90"] + }, + { + "metric": {}, + "value": [1731988543.752, "95"] + } + ] + } + }`)), + StatusCode: http.StatusOK, + }, nil + }, + }, + }, + }, + { + name: "negative float error", + err: errors.New( + "backpressure query (sum(throughput)) must have non-negative value: -90.000000", + ), + query: "sum(throughput)", + client: &http.Client{ + Transport: &proxymw.Mocker{ + RoundTripFunc: func(r *http.Request) (*http.Response, error) { + return &http.Response{ + Body: io.NopCloser(bytes.NewBufferString( + `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + { + "metric": {}, + "value": [1731988543.752, "-90"] + } + ] + } + }`)), + StatusCode: http.StatusOK, + }, nil + }, + }, + }, + }, + { + name: "bad status code throws error", + err: fmt.Errorf("unexpected status code: %d", http.StatusBadGateway), + endpoint: u, + client: &http.Client{ + Transport: &proxymw.Mocker{ + RoundTripFunc: func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + }, nil + }, + }, + }, + }, + { + name: "valid request and response", + query: "sum(throughput)", + val: 90, + endpoint: u, + client: &http.Client{ + Transport: &proxymw.Mocker{ + RoundTripFunc: func(r *http.Request) (*http.Response, error) { + url := u + proxymw.InstantQueryEndpoint + "?query=sum%28throughput%29" + require.Equal(t, url, r.URL.String()) + return &http.Response{ + Body: io.NopCloser(bytes.NewBufferString( + `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + { + "metric": {}, + "value": [1731988543.752, "90"] + } + ] + } + }`)), + StatusCode: http.StatusOK, + }, nil + }, + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + val, err := proxymw.ValueFromPromQL(ctx, tt.client, tt.endpoint, tt.query) + require.Equal(t, tt.err, err) + require.Equal(t, tt.val, val) + }) + } +}