diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 38d06c0..f0791c6 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -6,7 +6,6 @@ on: - "go.sum" - "go.mod" - "**.go" - - "scripts/errcheck_excludes.txt" - ".github/workflows/golangci-lint.yml" - ".golangci.yml" pull_request: diff --git a/Makefile b/Makefile index d74b103..07ffc61 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ -.PHONY: all build lint fmt lintfix test checkpath check cover test-norace deps help version docker-clean - GO = go GOIMPORTS = goimports SOURCES := $(shell find . -name '*.go') VERSION := $(shell git describe --tags --always --dirty) +.PHONY: all all: build +.PHONY: help help: @echo "Available targets:" @echo " all : Build the project" @@ -18,6 +18,7 @@ help: @echo " deps : Update dependencies" @echo " version : Show current version" +.PHONY: version version: @echo "Current version: $(VERSION)" @@ -25,51 +26,64 @@ define check_binary @command -v $(1) >/dev/null 2>&1 || { echo "Error: $(1) binary not in PATH"; exit 1; } endef +.PHONY: checkpath checkpath: $(call check_binary,go) +.PHONY: check check: checkpath $(call check_binary,golangci-lint) +.PHONY: docker-clean docker-clean: docker compose down --volumes docker compose down --rmi all +.PHONY: clean clean: docker-clean rm -f throttle-proxy +.PHONY: build build: throttle-proxy +.PHONY: throttle-proxy throttle-proxy: $(SOURCES) @echo ">> building binaries..." @$(GO) build -o $@ github.com/kevindweb/throttle-proxy +.PHONY: fmt fmt: go fmt ./... +.PHONY: lint lint: fmt @$(GOIMPORTS) -l -w -local $(shell head -n 1 go.mod | cut -d ' ' -f 2) . @golangci-lint run +.PHONY: lintfix lintfix: fmt @golangci-lint run --fix +.PHONY: ruff ruff: ruff check . TEST_FLAGS := -v -coverprofile .cover/cover.out TEST_PATH := ./... +.PHONY: test test: @echo 'Running unit tests...' @mkdir -p .cover @GOFLAGS=$(GOFLAGS) go test $(TEST_FLAGS) -race -count=10 $(TEST_PATH) +.PHONY: test-norace test-norace: @echo 'Running unit tests without race detection...' @mkdir -p .cover @GOFLAGS=$(GOFLAGS) go test $(TEST_FLAGS) $(TEST_PATH) +.PHONY: cover cover: check ifndef CI go tool cover -html .cover/cover.out @@ -77,6 +91,7 @@ else go tool cover -html .cover/cover.out -o .cover/all.html endif +.PHONY: deps deps: go get -u ./... go mod tidy diff --git a/README.md b/README.md index 6246a9c..7b8193b 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ docker compose down docker-compose up --build ``` -- Generate fake traffic with `./sandbox/traffic.py` +- Generate fake traffic with `./scripts/traffic_generator.py` - View metrics in the [local Grafana instance](http://localhost:3000/d/be68n82lvzg8wa/throttle-proxy-metrics) ### Lint and Test diff --git a/docker-compose.yaml b/docker-compose.yaml index cc66e45..b03543a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -9,7 +9,8 @@ services: -internal-listen-address=0.0.0.0:7776 -proxy-write-timeout=1m -proxy-read-timeout=1m - -unsafe-passthrough-paths=/graph,/static,/manifest.json,/api/v1 + -proxy-paths=/api/v1/query + -passthrough-paths=/graph,/static,/manifest.json -enable-jitter=true -jitter-delay=1s -enable-observer=true diff --git a/examples/README.md b/examples/README.md index 09731e3..90ffd32 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,10 +4,12 @@ `go install github.com/kevindweb/throttle-proxy@latest` -## Run +### Locally `make build` +## Usage + ### Config File ``` @@ -20,7 +22,8 @@ throttle-proxy -config-file examples/config.yaml throttle-proxy -upstream=http://localhost:9095 \ -insecure-listen-address=0.0.0.0:7777 \ -internal-listen-address=0.0.0.0:7776 \ - -unsafe-passthrough-paths=/api/v2 \ + -proxy-paths=/api/v2/endpoint-to-proxy \ + -passthrough-paths=/api/v2/endpoint-to-passthrough \ -proxy-read-timeout=30s \ -proxy-write-timeout=30s \ -enable-jitter=true \ diff --git a/examples/config.yaml b/examples/config.yaml index a97298d..010d915 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -1,8 +1,10 @@ upstream: http://localhost:9095 insecure_listen_addr: 0.0.0.0:7777 internal_listen_addr: 0.0.0.0:7776 -unsafe_passthrough_paths: - - /api/v2 +proxy_paths: + - /api/v1/query +passthrough_paths: + - /favicon.ico proxy_read_timeout: 5s proxy_write_timeout: 5s proxymw_config: diff --git a/examples/roundtripper/README.md b/examples/roundtripper/README.md index 234bca6..f210583 100644 --- a/examples/roundtripper/README.md +++ b/examples/roundtripper/README.md @@ -3,12 +3,6 @@ ### Run ``` -PROXY_URL="http://localhost:9095" \ -PROXY_QUERY="vector(82)" \ -PROXY_WARN="80" \ -PROXY_EMERGENCY="100" \ -CWIND_MIN="1" \ -CWIND_MAX="100" \ -JITTER_PROXY_DELAY="1s" \ +CONFIG_FILE="examples/config.yaml" \ go run examples/roundtripper/main.go ``` diff --git a/examples/roundtripper/main.go b/examples/roundtripper/main.go index ddaa291..0aef08d 100644 --- a/examples/roundtripper/main.go +++ b/examples/roundtripper/main.go @@ -6,85 +6,48 @@ import ( "log" "net/http" "os" - "strconv" "time" + "gopkg.in/yaml.v3" + "github.com/kevindweb/throttle-proxy/proxymw" ) -func FullConfigRoundTripper() (*proxymw.RoundTripperEntry, error) { - u := os.Getenv("PROXY_URL") - if u == "" { - return nil, fmt.Errorf("empty PROXY_URL") - } - query := os.Getenv("PROXY_QUERY") - if query == "" { - return nil, fmt.Errorf("empty PROXY_QUERY") - } - warn, err := getEnvInt("PROXY_WARN") - if err != nil { - return nil, err - } - emer, err := getEnvInt("PROXY_EMERGENCY") - if err != nil { - return nil, err - } - cwndmin, err := getEnvInt("CWIND_MIN") - if err != nil { - return nil, err - } - cwndmax, err := getEnvInt("CWIND_MAX") +func fullConfigRoundTripper(ctx context.Context) (*proxymw.RoundTripperEntry, error) { + cfg, err := parseConfigFile(os.Getenv("CONFIG_FILE")) if err != nil { return nil, err } - delay := os.Getenv("JITTER_PROXY_DELAY") - jitterDelay, err := time.ParseDuration(delay) - if err != nil { - return nil, err - } - - cfg := proxymw.Config{ - BackpressureConfig: proxymw.BackpressureConfig{ - EnableBackpressure: true, - BackpressureMonitoringURL: u, - BackpressureQueries: []proxymw.BackpressureQuery{ - { - Query: query, - WarningThreshold: float64(warn), - EmergencyThreshold: float64(emer), - }, - }, - CongestionWindowMin: cwndmin, - CongestionWindowMax: cwndmax, - }, - - EnableJitter: true, - JitterDelay: jitterDelay, - - EnableObserver: true, - } - mw, err := proxymw.NewRoundTripperFromConfig(cfg, http.DefaultTransport) if err != nil { return nil, err } - mw.Init(context.Background()) + mw.Init(ctx) return mw, err } -func getEnvInt(s string) (int, error) { - a := os.Getenv(s) - i, err := strconv.Atoi(a) +func parseConfigFile(configFile string) (proxymw.Config, error) { + // nolint:gosec // accept configuration file as input + file, err := os.Open(configFile) if err != nil { - return 0, err + return proxymw.Config{}, fmt.Errorf("error opening config file: %v", err) + } + defer file.Close() + + var cfg proxymw.Config + decoder := yaml.NewDecoder(file) + if err := decoder.Decode(&cfg); err != nil { + return proxymw.Config{}, fmt.Errorf("error decoding YAML: %v", err) } - return i, nil + + return cfg, nil } func main() { - rt, err := FullConfigRoundTripper() + ctx := context.Background() + rt, err := fullConfigRoundTripper(ctx) if err != nil { log.Fatal(err) } @@ -94,13 +57,7 @@ func main() { Transport: rt, } - ctx := context.Background() - request, err := http.NewRequestWithContext( - ctx, - http.MethodGet, - "https://google.com", - http.NoBody, - ) + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://x.com", http.NoBody) if err != nil { log.Fatal(fmt.Errorf("failed to make request: %w", err)) } diff --git a/main.go b/main.go index 76b5728..f87d49d 100644 --- a/main.go +++ b/main.go @@ -2,332 +2,30 @@ package main import ( "context" - "encoding/json" "errors" - "flag" "fmt" "log" "net" "net/http" - "net/http/httputil" "net/url" "os" "os/signal" - "strconv" - "strings" "syscall" "time" "github.com/metalmatze/signal/internalserver" "github.com/prometheus/client_golang/prometheus" - "gopkg.in/yaml.v3" _ "go.uber.org/automaxprocs" "github.com/kevindweb/throttle-proxy/proxymw" + "github.com/kevindweb/throttle-proxy/proxyutil" ) -type routes struct { - upstream *url.URL - handler http.Handler - - mux http.Handler - - logger *log.Logger -} - -func NewRoutes( - ctx context.Context, cfg proxymw.Config, passthroughs []string, upstream *url.URL, -) (*routes, error) { - proxy := httputil.NewSingleHostReverseProxy(upstream) - - r := &routes{ - upstream: upstream, - handler: proxy, - logger: log.Default(), - } - - mux := http.NewServeMux() - - mw, err := proxymw.NewServeFromConfig(cfg, r.passthrough) - if err != nil { - return nil, fmt.Errorf("failed to create middleware from config: %v", err) - } - - mw.Init(ctx) - - mux.Handle("/api/v1/query", mw.Proxy()) - mux.Handle("/api/v1/query_range", mw.Proxy()) - mux.Handle("/federate", http.HandlerFunc(r.passthrough)) - mux.Handle("/api/v1/alerts", http.HandlerFunc(r.passthrough)) - mux.Handle("/api/v1/rules", http.HandlerFunc(r.passthrough)) - mux.Handle("/api/v1/series", http.HandlerFunc(r.passthrough)) - mux.Handle("/api/v1/query_exemplars", http.HandlerFunc(r.passthrough)) - mux.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) - })) - - // Register optional passthrough paths. - for _, path := range passthroughs { - mux.Handle(path, http.HandlerFunc(r.passthrough)) - } - - r.mux = mux - proxy.ErrorLog = log.Default() - - return r, nil -} - -func (r *routes) ServeHTTP(w http.ResponseWriter, req *http.Request) { - r.mux.ServeHTTP(w, req) -} - -func (r *routes) passthrough(w http.ResponseWriter, req *http.Request) { - r.handler.ServeHTTP(w, req) -} - -type Config struct { - // InsecureListenAddress is the address the proxy HTTP server should listen on - InsecureListenAddress string `yaml:"insecure_listen_addr"` - // InternalListenAddress is the address the HTTP server should listen on for pprof and metrics - InternalListenAddress string `yaml:"internal_listen_addr"` - // Upstream is the upstream URL to proxy to - Upstream string `yaml:"upstream"` - UnsafePassthroughPaths []string `yaml:"unsafe_passthrough_paths"` - ProxyConfig proxymw.Config `yaml:"proxymw_config"` - ReadTimeout string `yaml:"proxy_read_timeout"` - WriteTimeout string `yaml:"proxy_write_timeout"` -} - -type StringSlice []string - -func (s *StringSlice) String() string { - return strings.Join(*s, ",") -} - -func (s *StringSlice) Set(value string) error { - *s = append(*s, value) - return nil -} - -type Float64Slice []float64 - -func (f *Float64Slice) String() string { - values := make([]string, len(*f)) - for i, v := range *f { - values[i] = fmt.Sprintf("%g", v) - } - return strings.Join(values, ",") -} - -func (f *Float64Slice) Set(value string) error { - v, err := strconv.ParseFloat(value, 64) - if err != nil { - return err - } - *f = append(*f, v) - return nil -} - -func parseConfigs() (Config, error) { - var ( - insecureListenAddress string - internalListenAddress string - readTimeout string - writeTimeout string - upstream string - unsafePassthroughPaths string - enableBackpressure bool - backpressureMonitoringURL string - backpressureQueries StringSlice - backpressureQueryNames StringSlice - backpressureWarnThresholds Float64Slice - backpressureEmergencyThresholds Float64Slice - congestionWindowMin int - congestionWindowMax int - enableJitter bool - jitterDelay time.Duration - enableObserver bool - configFile string - ) - - flagset := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - flagset.StringVar(&configFile, "config-file", "", "Config file to initialize the proxy") - flagset.StringVar( - &insecureListenAddress, "insecure-listen-address", "", - "The address the proxy HTTP server should listen on.", - ) - flagset.StringVar( - &internalListenAddress, "internal-listen-address", "", - "The address the internal HTTP server should listen on to expose metrics about itself.", - ) - flagset.StringVar( - &readTimeout, "proxy-read-timeout", (time.Minute * 5).String(), - "HTTP read timeout duration", - ) - flagset.StringVar( - &writeTimeout, "proxy-write-timeout", (time.Minute * 5).String(), - "HTTP write timeout duration", - ) - flagset.StringVar(&upstream, "upstream", "", "The upstream URL to proxy to.") - flagset.BoolVar(&enableJitter, "enable-jitter", false, "Use the jitter middleware") - flagset.DurationVar( - &jitterDelay, "jitter-delay", 0, - "Random jitter to apply when enabled", - ) - flagset.BoolVar( - &enableBackpressure, "enable-bp", false, - "Use the additive increase multiplicative decrease middleware using backpressure metrics", - ) - flagset.IntVar( - &congestionWindowMin, "bp-min-window", 0, - "Min concurrent queries to passthrough regardless of spikes in backpressure.", - ) - flagset.IntVar( - &congestionWindowMax, "bp-max-window", 0, - "Max concurrent queries to passthrough regardless of backpressure health.", - ) - flagset.StringVar( - &backpressureMonitoringURL, "bp-monitoring-url", "", - "The address on which to read backpressure metrics with PromQL queries.", - ) - flagset.Var( - &backpressureQueries, "bp-query", - "PromQL that signifies an increase in downstream failure", - ) - flagset.Var( - &backpressureQueryNames, "bp-query-name", - "Name is an optional human readable field used to emit tagged metrics. "+ - "When unset, operational metrics are omitted. "+ - `When set, read warn_threshold as proxymw_bp_warn_threshold{query_name=""}`, - ) - flagset.Var( - &backpressureWarnThresholds, "bp-warn", - "Threshold that defines when the system should start backing off", - ) - flagset.Var( - &backpressureEmergencyThresholds, "bp-emergency", - "Threshold that defines when the system should apply maximum throttling", - ) - flagset.BoolVar( - &enableObserver, "enable-observer", false, - "Collect middleware latency and error metrics", - ) - flagset.StringVar(&unsafePassthroughPaths, "unsafe-passthrough-paths", "", - "Comma delimited allow list of exact HTTP paths that should be allowed to hit "+ - "the upstream URL without any enforcement.") - if err := flagset.Parse(os.Args[1:]); err != nil { - return Config{}, err - } - - if configFile != "" { - return parseConfigFile(configFile) - } - - unsafePaths, err := parseUnsafePaths(unsafePassthroughPaths) - if err != nil { - return Config{}, err - } - - n := len(backpressureQueries) - queries := make([]proxymw.BackpressureQuery, n) - if len(backpressureQueryNames) != n && len(backpressureQueryNames) != 0 { - return Config{}, fmt.Errorf("number of backpressure query names should be 0 or %d", n) - } - - if len(backpressureWarnThresholds) != n { - return Config{}, fmt.Errorf("expected %d warn thresholds for %d backpressure queries", n, n) - } - - if len(backpressureEmergencyThresholds) != n { - return Config{}, fmt.Errorf( - "expected %d emergency thresholds for %d backpressure queries", n, n, - ) - } - - for i, query := range backpressureQueries { - queryName := "" - if len(backpressureQueryNames) > 0 { - queryName = backpressureQueryNames[i] - } - queries[i] = proxymw.BackpressureQuery{ - Name: queryName, - Query: query, - WarningThreshold: backpressureWarnThresholds[i], - EmergencyThreshold: backpressureEmergencyThresholds[i], - } - } - - return Config{ - InsecureListenAddress: insecureListenAddress, - InternalListenAddress: internalListenAddress, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - Upstream: upstream, - UnsafePassthroughPaths: unsafePaths, - ProxyConfig: proxymw.Config{ - EnableJitter: enableJitter, - JitterDelay: jitterDelay, - EnableObserver: enableObserver, - BackpressureConfig: proxymw.BackpressureConfig{ - EnableBackpressure: enableBackpressure, - BackpressureMonitoringURL: backpressureMonitoringURL, - BackpressureQueries: queries, - CongestionWindowMin: congestionWindowMin, - CongestionWindowMax: congestionWindowMax, - }, - }, - }, nil -} - -func parseUnsafePaths(unsafePaths string) ([]string, error) { - if unsafePaths == "" { - return []string{}, nil - } - - paths := strings.Split(unsafePaths, ",") - for _, path := range paths { - u, err := url.Parse(fmt.Sprintf("http://example.com%v", path)) - if err != nil { - return nil, fmt.Errorf( - "path %q is not a valid URI path, got %v", path, unsafePaths, - ) - } - if u.Path != path { - return nil, fmt.Errorf( - "path %q is not a valid URI path, got %v", path, unsafePaths, - ) - } - if u.Path == "" || u.Path == "/" { - return nil, fmt.Errorf( - "path %q is not allowed, got %v", u.Path, unsafePaths, - ) - } - } - return paths, nil -} - -func parseConfigFile(configFile string) (Config, error) { - // nolint:gosec // accept configuration file as input - file, err := os.Open(configFile) - if err != nil { - return Config{}, fmt.Errorf("error opening config file: %v", err) - } - defer file.Close() - - var cfg Config - decoder := yaml.NewDecoder(file) - if err := decoder.Decode(&cfg); err != nil { - return Config{}, fmt.Errorf("error decoding YAML: %v", err) - } - - return cfg, nil -} - func main() { - cfg, err := parseConfigs() + cfg, err := proxyutil.ParseConfigs() if err != nil { - log.Fatalf("Failed to parse config file: %v", err) + log.Fatalf("Failed to parse flags: %v", err) } ctx := context.Background() @@ -367,7 +65,7 @@ func main() { } } -func setupInsecureServer(ctx context.Context, cfg Config) (*http.Server, error) { +func setupInsecureServer(ctx context.Context, cfg proxyutil.Config) (*http.Server, error) { upstreamURL, err := parseUpstream(cfg.Upstream) if err != nil { return nil, err @@ -387,7 +85,9 @@ func setupInsecureServer(ctx context.Context, cfg Config) (*http.Server, error) cfg.ProxyConfig.ClientTimeout = 2 * readTimeout } - routes, err := NewRoutes(ctx, cfg.ProxyConfig, cfg.UnsafePassthroughPaths, upstreamURL) + routes, err := proxymw.NewRoutes( + ctx, cfg.ProxyConfig, cfg.ProxyPaths, cfg.PassthroughPaths, upstreamURL, + ) if err != nil { return nil, fmt.Errorf("failed to create proxymw Routes: %v", err) } @@ -432,7 +132,7 @@ func parseUpstream(upstream string) (*url.URL, error) { return upstreamURL, nil } -func setupInternalServer(cfg Config) (*http.Server, error) { +func setupInternalServer(cfg proxyutil.Config) (*http.Server, error) { if cfg.InternalListenAddress == "" { return nil, nil } diff --git a/proxymw/middleware.go b/proxymw/middleware.go index d41a97a..aab61d2 100644 --- a/proxymw/middleware.go +++ b/proxymw/middleware.go @@ -99,7 +99,7 @@ type ServeEntry struct { // NewServeFromConfig constructs a middleware chain based on configuration. // The middleware chain is constructed in the following order: -// 1. HTTP Request wrapping (Entry) +// 1. Request wrapping (Entry) // 2. Metrics collection (Observer) // 3. Request spreading (Jitter) // 4. Adaptive rate limiting (Backpressure) diff --git a/proxymw/routes.go b/proxymw/routes.go new file mode 100644 index 0000000..18ab41f --- /dev/null +++ b/proxymw/routes.go @@ -0,0 +1,63 @@ +package proxymw + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" +) + +type routes struct { + upstream *url.URL + handler http.Handler + + mux http.Handler +} + +func NewRoutes( + ctx context.Context, cfg Config, proxyPaths, passthroughPaths []string, upstream *url.URL, +) (http.Handler, error) { + proxy := httputil.NewSingleHostReverseProxy(upstream) + proxy.ErrorLog = log.Default() + + r := &routes{ + upstream: upstream, + handler: proxy, + } + + mw, err := NewServeFromConfig(cfg, r.passthrough) + if err != nil { + return nil, fmt.Errorf("failed to create middleware from config: %v", err) + } + + mw.Init(ctx) + + mux := http.NewServeMux() + mux.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewEncoder(w).Encode(map[string]bool{"ok": true}); err != nil { + log.Printf("error writing healthz endpoint: %v", err) + } + })) + + for _, path := range proxyPaths { + mux.Handle(path, mw.Proxy()) + } + + for _, path := range passthroughPaths { + mux.Handle(path, http.HandlerFunc(r.passthrough)) + } + + r.mux = mux + return r, nil +} + +func (r *routes) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.mux.ServeHTTP(w, req) +} + +func (r *routes) passthrough(w http.ResponseWriter, req *http.Request) { + r.handler.ServeHTTP(w, req) +} diff --git a/proxymw/routes_test.go b/proxymw/routes_test.go new file mode 100644 index 0000000..8edccb6 --- /dev/null +++ b/proxymw/routes_test.go @@ -0,0 +1,100 @@ +package proxymw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/kevindweb/throttle-proxy/proxymw" +) + +func TestInvalidJitterConfig(t *testing.T) { + upstream, err := url.Parse("http://google.com") + require.NoError(t, err) + + ctx := context.Background() + cfg := proxymw.Config{ + EnableJitter: true, + JitterDelay: 0, + } + + routes, err := proxymw.NewRoutes(ctx, cfg, []string{}, []string{}, upstream) + require.ErrorAs(t, err, &proxymw.ErrJitterDelayRequired) + require.Nil(t, routes) +} + +func TestNewRoutes(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream response")) + })) + defer upstream.Close() + + upstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("Failed to parse upstream URL: %v", err) + } + + cfg := proxymw.Config{ + EnableJitter: false, + ClientTimeout: time.Second, + } + + ctx := context.Background() + proxies := []string{"/test-proxy"} + passthroughs := []string{"/test-passthrough"} + routes, err := proxymw.NewRoutes(ctx, cfg, proxies, passthroughs, upstreamURL) + if err != nil { + t.Fatalf("Failed to create routes: %v", err) + } + + testServer := httptest.NewServer(routes) + defer testServer.Close() + + testCases := []struct { + name string + path string + expectedStatus int + }{ + { + name: "Health Check", + path: "/healthz", + expectedStatus: http.StatusOK, + }, + { + name: "Passthrough Path", + path: "/test-proxy", + expectedStatus: http.StatusOK, + }, + { + name: "Passthrough Path", + path: "/test-passthrough", + expectedStatus: http.StatusOK, + }, + { + name: "Not a passthrough", + path: "/non-passthrough", + expectedStatus: http.StatusNotFound, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + u := testServer.URL + tt.path + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, http.NoBody) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + if resp.StatusCode != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + } + }) + } +} diff --git a/proxyutil/flag.go b/proxyutil/flag.go new file mode 100644 index 0000000..d5eb364 --- /dev/null +++ b/proxyutil/flag.go @@ -0,0 +1,279 @@ +// Package proxyutil handles parsing logic for proxymw configs +package proxyutil + +import ( + "flag" + "fmt" + "net/url" + "os" + "strconv" + "strings" + "time" + + "gopkg.in/yaml.v3" + + "github.com/kevindweb/throttle-proxy/proxymw" +) + +type Config struct { + // InsecureListenAddress is the address the proxy HTTP server should listen on + InsecureListenAddress string `yaml:"insecure_listen_addr"` + // InternalListenAddress is the address the HTTP server should listen on for pprof and metrics + InternalListenAddress string `yaml:"internal_listen_addr"` + // Upstream is the upstream URL to proxy to + Upstream string `yaml:"upstream"` + // ProxyPaths is the list of paths to throttle with proxy settings + ProxyPaths []string `yaml:"proxy_paths"` + // PassthroughPaths is a list of paths to pass through instead of applying proxy settings + PassthroughPaths []string `yaml:"passthrough_paths"` + ProxyConfig proxymw.Config `yaml:"proxymw_config"` + ReadTimeout string `yaml:"proxy_read_timeout"` + WriteTimeout string `yaml:"proxy_write_timeout"` +} + +type StringSlice []string + +func (s *StringSlice) String() string { + return strings.Join(*s, ",") +} + +func (s *StringSlice) Set(value string) error { + *s = append(*s, value) + return nil +} + +type Float64Slice []float64 + +func (f *Float64Slice) String() string { + values := make([]string, len(*f)) + for i, v := range *f { + values[i] = fmt.Sprintf("%g", v) + } + return strings.Join(values, ",") +} + +func (f *Float64Slice) Set(value string) error { + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *f = append(*f, v) + return nil +} + +func ParseConfigs() (Config, error) { + var ( + insecureListenAddress string + internalListenAddress string + readTimeout string + writeTimeout string + upstream string + proxyPaths string + passthroughPaths string + enableBackpressure bool + backpressureMonitoringURL string + bpQueries StringSlice + bpQueryNames StringSlice + bpWarnThresholds Float64Slice + bpEmergencyThresholds Float64Slice + congestionWindowMin int + congestionWindowMax int + enableJitter bool + jitterDelay time.Duration + enableObserver bool + configFile string + ) + + flagset := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + flagset.StringVar(&configFile, "config-file", "", "Config file to initialize the proxy") + flagset.StringVar( + &insecureListenAddress, "insecure-listen-address", "", + "The address the proxy HTTP server should listen on.", + ) + flagset.StringVar( + &internalListenAddress, "internal-listen-address", "", + "The address the internal HTTP server should listen on to expose metrics about itself.", + ) + flagset.StringVar( + &readTimeout, "proxy-read-timeout", (time.Minute * 5).String(), + "HTTP read timeout duration", + ) + flagset.StringVar( + &writeTimeout, "proxy-write-timeout", (time.Minute * 5).String(), + "HTTP write timeout duration", + ) + flagset.StringVar(&upstream, "upstream", "", "The upstream URL to proxy to.") + flagset.BoolVar(&enableJitter, "enable-jitter", false, "Use the jitter middleware") + flagset.DurationVar( + &jitterDelay, "jitter-delay", 0, + "Random jitter to apply when enabled", + ) + flagset.BoolVar( + &enableBackpressure, "enable-bp", false, + "Use the additive increase multiplicative decrease middleware using backpressure metrics", + ) + flagset.IntVar( + &congestionWindowMin, "bp-min-window", 0, + "Min concurrent queries to passthrough regardless of spikes in backpressure.", + ) + flagset.IntVar( + &congestionWindowMax, "bp-max-window", 0, + "Max concurrent queries to passthrough regardless of backpressure health.", + ) + flagset.StringVar( + &backpressureMonitoringURL, "bp-monitoring-url", "", + "The address on which to read backpressure metrics with PromQL queries.", + ) + flagset.Var( + &bpQueries, "bp-query", + "PromQL that signifies an increase in downstream failure", + ) + flagset.Var( + &bpQueryNames, "bp-query-name", + "Name is an optional human readable field used to emit tagged metrics. "+ + "When unset, operational metrics are omitted. "+ + `When set, read warn_threshold as proxymw_bp_warn_threshold{query_name=""}`, + ) + flagset.Var( + &bpWarnThresholds, "bp-warn", + "Threshold that defines when the system should start backing off", + ) + flagset.Var( + &bpEmergencyThresholds, "bp-emergency", + "Threshold that defines when the system should apply maximum throttling", + ) + flagset.BoolVar( + &enableObserver, "enable-observer", false, + "Collect middleware latency and error metrics", + ) + flagset.StringVar(&proxyPaths, "proxy-paths", "", + "Comma delimited allow list of exact HTTP paths that should be allowed to hit "+ + "the upstream URL without any enforcement.") + flagset.StringVar(&passthroughPaths, "passthrough-paths", "", + "Comma delimited allow list of exact HTTP paths that should be allowed to hit "+ + "the upstream URL without any enforcement.") + if err := flagset.Parse(os.Args[1:]); err != nil { + return Config{}, err + } + + queries, err := parseBackpressureQueries( + bpQueries, bpQueryNames, bpWarnThresholds, bpEmergencyThresholds, + ) + if err != nil { + return Config{}, err + } + + if configFile != "" { + return parseConfigFile(configFile) + } + + proxyPathsList, err := parsePaths(proxyPaths) + if err != nil { + return Config{}, err + } + + passthroughPathsList, err := parsePaths(passthroughPaths) + if err != nil { + return Config{}, err + } + + return Config{ + InsecureListenAddress: insecureListenAddress, + InternalListenAddress: internalListenAddress, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + Upstream: upstream, + ProxyPaths: proxyPathsList, + PassthroughPaths: passthroughPathsList, + ProxyConfig: proxymw.Config{ + EnableJitter: enableJitter, + JitterDelay: jitterDelay, + EnableObserver: enableObserver, + BackpressureConfig: proxymw.BackpressureConfig{ + EnableBackpressure: enableBackpressure, + BackpressureMonitoringURL: backpressureMonitoringURL, + BackpressureQueries: queries, + CongestionWindowMin: congestionWindowMin, + CongestionWindowMax: congestionWindowMax, + }, + }, + }, nil +} + +func parsePaths(paths string) ([]string, error) { + if paths == "" { + return []string{}, nil + } + + pathList := strings.Split(paths, ",") + for _, path := range pathList { + u, err := url.Parse(fmt.Sprintf("http://example.com%v", path)) + if err != nil { + return nil, fmt.Errorf( + "path %q is not a valid URI path, got %v", path, paths, + ) + } + if u.Path != path { + return nil, fmt.Errorf( + "path %q is not a valid URI path, got %v", path, paths, + ) + } + if u.Path == "" || u.Path == "/" { + return nil, fmt.Errorf( + "path %q is not allowed, got %v", u.Path, paths, + ) + } + } + return pathList, nil +} + +func parseBackpressureQueries( + bpQueries, bpQueryNames []string, bpWarnThresholds, bpEmergencyThresholds []float64, +) ([]proxymw.BackpressureQuery, error) { + n := len(bpQueries) + queries := make([]proxymw.BackpressureQuery, n) + if len(bpQueryNames) != n && len(bpQueryNames) != 0 { + return nil, fmt.Errorf("number of backpressure query names should be 0 or %d", n) + } + + if len(bpWarnThresholds) != n { + return nil, fmt.Errorf("expected %d warn thresholds for %d backpressure queries", n, n) + } + + if len(bpEmergencyThresholds) != n { + return nil, fmt.Errorf( + "expected %d emergency thresholds for %d backpressure queries", n, n, + ) + } + + for i, query := range bpQueries { + queryName := "" + if len(bpQueryNames) > 0 { + queryName = bpQueryNames[i] + } + queries[i] = proxymw.BackpressureQuery{ + Name: queryName, + Query: query, + WarningThreshold: bpWarnThresholds[i], + EmergencyThreshold: bpEmergencyThresholds[i], + } + } + return queries, nil +} + +func parseConfigFile(configFile string) (Config, error) { + // nolint:gosec // accept configuration file as input + file, err := os.Open(configFile) + if err != nil { + return Config{}, fmt.Errorf("error opening config file: %v", err) + } + defer file.Close() + + var cfg Config + decoder := yaml.NewDecoder(file) + if err := decoder.Decode(&cfg); err != nil { + return Config{}, fmt.Errorf("error decoding YAML: %v", err) + } + + return cfg, nil +} diff --git a/main_test.go b/proxyutil/flag_test.go similarity index 54% rename from main_test.go rename to proxyutil/flag_test.go index 8246387..daf53b0 100644 --- a/main_test.go +++ b/proxyutil/flag_test.go @@ -1,10 +1,6 @@ -package main +package proxyutil_test import ( - "context" - "net/http" - "net/http/httptest" - "net/url" "os" "testing" "time" @@ -12,103 +8,15 @@ import ( "github.com/stretchr/testify/require" "github.com/kevindweb/throttle-proxy/proxymw" + "github.com/kevindweb/throttle-proxy/proxyutil" ) -func TestNewRoutes(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("upstream response")) - })) - defer upstream.Close() - - upstreamURL, err := url.Parse(upstream.URL) - if err != nil { - t.Fatalf("Failed to parse upstream URL: %v", err) - } - - cfg := proxymw.Config{ - EnableJitter: false, - ClientTimeout: time.Second, - } - - ctx := context.Background() - routes, err := NewRoutes(ctx, cfg, []string{"/test-passthrough"}, upstreamURL) - if err != nil { - t.Fatalf("Failed to create routes: %v", err) - } - - testServer := httptest.NewServer(routes) - defer testServer.Close() - - testCases := []struct { - name string - path string - expectedStatus int - }{ - { - name: "Health Check", - path: "/healthz", - expectedStatus: http.StatusOK, - }, - { - name: "Passthrough Path", - path: "/test-passthrough", - expectedStatus: http.StatusOK, - }, - { - name: "Prometheus Query", - path: "/api/v1/query", - expectedStatus: http.StatusOK, - }, - { - name: "Query Range", - path: "/api/v1/query_range", - expectedStatus: http.StatusOK, - }, - { - name: "Not a passthrough", - path: "/non-passthrough", - expectedStatus: http.StatusNotFound, - }, - } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - u := testServer.URL + tt.path - ctx := context.Background() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, http.NoBody) - require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - - defer resp.Body.Close() - if resp.StatusCode != tt.expectedStatus { - t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) - } - }) - } -} - -func TestInvalidJitterConfig(t *testing.T) { - upstream, err := url.Parse("http://google.com") - require.NoError(t, err) - - ctx := context.Background() - cfg := proxymw.Config{ - EnableJitter: true, - JitterDelay: 0, - } - - routes, err := NewRoutes(ctx, cfg, []string{}, upstream) - require.ErrorAs(t, err, &proxymw.ErrJitterDelayRequired) - require.Nil(t, routes) -} - func TestParseConfig(t *testing.T) { for _, tt := range []struct { name string args []string wantErr bool - cfg Config + cfg proxyutil.Config }{ { name: "default config flags", @@ -118,12 +26,13 @@ func TestParseConfig(t *testing.T) { "--insecure-listen-address", ":8080", }, wantErr: false, - cfg: Config{ - Upstream: "http://example.com", - InsecureListenAddress: ":8080", - ReadTimeout: (time.Minute * 5).String(), - WriteTimeout: (time.Minute * 5).String(), - UnsafePassthroughPaths: []string{}, + cfg: proxyutil.Config{ + Upstream: "http://example.com", + InsecureListenAddress: ":8080", + ReadTimeout: (time.Minute * 5).String(), + WriteTimeout: (time.Minute * 5).String(), + ProxyPaths: []string{}, + PassthroughPaths: []string{}, ProxyConfig: proxymw.Config{ EnableJitter: false, EnableObserver: false, @@ -141,7 +50,8 @@ func TestParseConfig(t *testing.T) { "--upstream", "http://example.com", "--insecure-listen-address", ":8080", "--internal-listen-address", ":9090", - "--unsafe-passthrough-paths", "/health,/metrics", + "--proxy-paths", "/api/v2", + "--passthrough-paths", "/health,/metrics", "--proxy-read-timeout", "2m", "--proxy-write-timeout", "3m", "--enable-observer=true", @@ -162,13 +72,14 @@ func TestParseConfig(t *testing.T) { "--enable-observer", }, wantErr: false, - cfg: Config{ - Upstream: "http://example.com", - UnsafePassthroughPaths: []string{"/health", "/metrics"}, - InsecureListenAddress: ":8080", - InternalListenAddress: ":9090", - ReadTimeout: "2m", - WriteTimeout: "3m", + cfg: proxyutil.Config{ + Upstream: "http://example.com", + ProxyPaths: []string{"/api/v2"}, + PassthroughPaths: []string{"/health", "/metrics"}, + InsecureListenAddress: ":8080", + InternalListenAddress: ":9090", + ReadTimeout: "2m", + WriteTimeout: "3m", ProxyConfig: proxymw.Config{ EnableJitter: true, JitterDelay: time.Millisecond * 100, @@ -202,7 +113,17 @@ func TestParseConfig(t *testing.T) { "test-program", "--upstream", "http://example.com", "--insecure-listen-address", ":8080", - "--unsafe-passthrough-paths", ",,,", + "--passthrough-paths", ",,,", + }, + wantErr: true, + }, + { + name: "empty proxy path", + args: []string{ + "test-program", + "--upstream", "http://example.com", + "--insecure-listen-address", ":8080", + "--proxy-paths", ",,,", }, wantErr: true, }, @@ -212,7 +133,25 @@ func TestParseConfig(t *testing.T) { "test-program", "--upstream", "http://example.com", "--insecure-listen-address", ":8080", - "--unsafe-passthrough-paths", "invalid path", + "--passthrough-paths", "invalid path", + }, + wantErr: true, + }, + { + name: "invalid query names", + args: []string{ + "test-program", + "--upstream", "http://example.com", + "--insecure-listen-address", ":8080", + "--passthrough-paths", "/api", + "--enable-bp", + "--bp-query", "up{job='prometheus'} == 0", + "--bp-query-name", "instances_down", + "--bp-emergency", "0.5", + "--bp-warn", "0.7", + "--bp-query", "up{job='prometheus'} == 1", + "--bp-emergency", "0.5", + "--bp-warn", "0.7", }, wantErr: true, }, @@ -234,19 +173,55 @@ func TestParseConfig(t *testing.T) { "test-program", "--upstream", "http://example.com", "--insecure-listen-address", ":8080", - "--unsafe-passthrough-paths", "/api", + "--passthrough-paths", "/api", "--enable-bp", "--bp-query", "up{job='prometheus'} == 0", "--bp-emergency", "0.5", }, wantErr: true, }, + { + name: "simple config file", + args: []string{ + "test-program", + "--config-file", "testdata/simple.yaml", + }, + cfg: proxyutil.Config{ + Upstream: "http://localhost:9095", + PassthroughPaths: []string{"/api/v2"}, + InsecureListenAddress: "0.0.0.0:7777", + InternalListenAddress: "0.0.0.0:7776", + ReadTimeout: "5s", + WriteTimeout: "5s", + ProxyConfig: proxymw.Config{ + EnableJitter: true, + JitterDelay: time.Second * 5, + EnableObserver: true, + }, + }, + }, + { + name: "invalid config file", + args: []string{ + "test-program", + "--config-file", "testdata/invalid.yaml", + }, + wantErr: true, + }, + { + name: "config file does not exist", + args: []string{ + "test-program", + "--config-file", "testdata/nonexistent.yaml", + }, + wantErr: true, + }, } { t.Run(tt.name, func(t *testing.T) { oldArgs := os.Args defer func() { os.Args = oldArgs }() os.Args = tt.args - cfg, err := parseConfigs() + cfg, err := proxyutil.ParseConfigs() require.Equal(t, err != nil, tt.wantErr) require.Equal(t, cfg, tt.cfg) }) diff --git a/proxyutil/testdata/invalid.yaml b/proxyutil/testdata/invalid.yaml new file mode 100644 index 0000000..b82aa76 --- /dev/null +++ b/proxyutil/testdata/invalid.yaml @@ -0,0 +1,2 @@ +invalid +valid: false diff --git a/proxyutil/testdata/simple.yaml b/proxyutil/testdata/simple.yaml new file mode 100644 index 0000000..e88ad92 --- /dev/null +++ b/proxyutil/testdata/simple.yaml @@ -0,0 +1,11 @@ +upstream: http://localhost:9095 +insecure_listen_addr: 0.0.0.0:7777 +internal_listen_addr: 0.0.0.0:7776 +passthrough_paths: + - /api/v2 +proxy_read_timeout: 5s +proxy_write_timeout: 5s +proxymw_config: + enable_jitter: true + jitter_delay: 5s + enable_observer: true diff --git a/sandbox/traffic.py b/scripts/traffic_generator.py similarity index 100% rename from sandbox/traffic.py rename to scripts/traffic_generator.py diff --git a/throttle-proxy b/throttle-proxy index f5d0a98..2b80e0e 100755 Binary files a/throttle-proxy and b/throttle-proxy differ