From 46f6bf3d0278037ce3c084ce7ded6ee0127f654d Mon Sep 17 00:00:00 2001 From: Kevin Deems Date: Mon, 30 Dec 2024 09:01:10 -0500 Subject: [PATCH] Clean up flag parsing --- README.md | 2 +- main.go | 30 +--- proxymw/backpressure.go | 34 +++++ proxyutil/flag.go | 295 +++++++++++----------------------------- proxyutil/flag_test.go | 16 +-- 5 files changed, 127 insertions(+), 250 deletions(-) diff --git a/README.md b/README.md index 7b8193b..f921f82 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ Build the docker-compose stack ``` make all docker compose down -docker-compose up --build +docker compose up --build ``` - Generate fake traffic with `./scripts/traffic_generator.py` diff --git a/main.go b/main.go index 318233e..58b55e6 100644 --- a/main.go +++ b/main.go @@ -65,18 +65,8 @@ func main() { } func setupInsecureServer(ctx context.Context, cfg proxyutil.Config) (*http.Server, error) { - readTimeout, err := time.ParseDuration(cfg.ReadTimeout) - if err != nil { - return nil, fmt.Errorf("error parsing read timeout: %v", err) - } - - writeTimeout, err := time.ParseDuration(cfg.WriteTimeout) - if err != nil { - return nil, fmt.Errorf("error parsing write timeout: %v", err) - } - if cfg.ProxyConfig.ClientTimeout == 0 { - cfg.ProxyConfig.ClientTimeout = 2 * readTimeout + cfg.ProxyConfig.ClientTimeout = 2 * cfg.ReadTimeout } routes, err := proxyhttp.NewRoutes(ctx, cfg) @@ -94,8 +84,8 @@ func setupInsecureServer(ctx context.Context, cfg proxyutil.Config) (*http.Serve srv := &http.Server{ Handler: mux, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, } go func() { @@ -129,20 +119,10 @@ func setupInternalServer(cfg proxyutil.Config) (*http.Server, error) { return nil, fmt.Errorf("failed to listen on internal address: %v", err) } - readTimeout, err := time.ParseDuration(cfg.ReadTimeout) - if err != nil { - return nil, fmt.Errorf("error parsing read timeout: %v", err) - } - - writeTimeout, err := time.ParseDuration(cfg.WriteTimeout) - if err != nil { - return nil, fmt.Errorf("error parsing write timeout: %v", err) - } - srv := &http.Server{ Handler: h, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, } go func() { diff --git a/proxymw/backpressure.go b/proxymw/backpressure.go index 42f13b5..9df8ed2 100644 --- a/proxymw/backpressure.go +++ b/proxymw/backpressure.go @@ -124,6 +124,40 @@ type BackpressureConfig struct { CongestionWindowMax int `yaml:"congestion_window_max"` } +func ParseBackpressureQueries( + bpQueries, bpQueryNames []string, bpWarnThresholds, bpEmergencyThresholds []float64, +) ([]BackpressureQuery, error) { + n := len(bpQueries) + queries := make([]BackpressureQuery, n) + if len(bpQueryNames) != n && len(bpQueryNames) != 0 { + return nil, fmt.Errorf("number of backpressure query names should be 0 or %d", n) + } + + if len(bpWarnThresholds) != n { + return nil, fmt.Errorf("expected %d warn thresholds for %d backpressure queries", n, n) + } + + if len(bpEmergencyThresholds) != n { + return nil, fmt.Errorf( + "expected %d emergency thresholds for %d backpressure queries", n, n, + ) + } + + for i, query := range bpQueries { + queryName := "" + if len(bpQueryNames) > 0 { + queryName = bpQueryNames[i] + } + queries[i] = BackpressureQuery{ + Name: queryName, + Query: query, + WarningThreshold: bpWarnThresholds[i], + EmergencyThreshold: bpEmergencyThresholds[i], + } + } + return queries, nil +} + func (c BackpressureConfig) Validate() error { if !c.EnableBackpressure { return nil diff --git a/proxyutil/flag.go b/proxyutil/flag.go index dec5eed..02de3f4 100644 --- a/proxyutil/flag.go +++ b/proxyutil/flag.go @@ -16,19 +16,14 @@ import ( ) type Config struct { - // InsecureListenAddress is the address the proxy HTTP server should listen on - InsecureListenAddress string `yaml:"insecure_listen_addr"` - // InternalListenAddress is the address the HTTP server should listen on for pprof and metrics - InternalListenAddress string `yaml:"internal_listen_addr"` - // Upstream is the upstream URL to proxy to - Upstream string `yaml:"upstream"` - // ProxyPaths is the list of paths to throttle with proxy settings - ProxyPaths []string `yaml:"proxy_paths"` - // PassthroughPaths is a list of paths to pass through instead of applying proxy settings - PassthroughPaths []string `yaml:"passthrough_paths"` - ProxyConfig proxymw.Config `yaml:"proxymw_config"` - ReadTimeout string `yaml:"proxy_read_timeout"` - WriteTimeout string `yaml:"proxy_write_timeout"` + InsecureListenAddress string `yaml:"insecure_listen_addr"` + InternalListenAddress string `yaml:"internal_listen_addr"` + Upstream string `yaml:"upstream"` + ProxyPaths []string `yaml:"proxy_paths"` + PassthroughPaths []string `yaml:"passthrough_paths"` + ProxyConfig proxymw.Config `yaml:"proxymw_config"` + ReadTimeout time.Duration `yaml:"proxy_read_timeout"` + WriteTimeout time.Duration `yaml:"proxy_write_timeout"` } type StringSlice []string @@ -62,106 +57,51 @@ func (f *Float64Slice) Set(value string) error { } func ParseConfigFlags() (Config, error) { - var ( - insecureListenAddress string - internalListenAddress string - readTimeout string - writeTimeout string - upstream string - proxyPaths string - passthroughPaths string - enableBackpressure bool - backpressureMonitoringURL string - bpQueries StringSlice - bpQueryNames StringSlice - bpWarnThresholds Float64Slice - bpEmergencyThresholds Float64Slice - congestionWindowMin int - congestionWindowMax int - enableCriticality bool - enableJitter bool - jitterDelay time.Duration - enableObserver bool - configFile string - ) - - flagset := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - flagset.StringVar(&configFile, "config-file", "", "Config file to initialize the proxy") - flagset.StringVar( - &insecureListenAddress, "insecure-listen-address", "", - "The address the proxy HTTP server should listen on.", - ) - flagset.StringVar( - &internalListenAddress, "internal-listen-address", "", - "The address the internal HTTP server should listen on to expose metrics about itself.", - ) - flagset.StringVar( - &readTimeout, "proxy-read-timeout", (time.Minute * 5).String(), - "HTTP read timeout duration", - ) - flagset.StringVar( - &writeTimeout, "proxy-write-timeout", (time.Minute * 5).String(), - "HTTP write timeout duration", - ) - flagset.StringVar(&upstream, "upstream", "", "The upstream URL to proxy to.") - flagset.BoolVar(&enableCriticality, "enable-criticality", false, "Read criticality headers") - flagset.BoolVar(&enableJitter, "enable-jitter", false, "Use the jitter middleware") - flagset.DurationVar( - &jitterDelay, "jitter-delay", 0, - "Random jitter to apply when enabled", - ) - flagset.BoolVar( - &enableBackpressure, "enable-bp", false, - "Use the additive increase multiplicative decrease middleware using backpressure metrics", - ) - flagset.IntVar( - &congestionWindowMin, "bp-min-window", 0, - "Min concurrent queries to passthrough regardless of spikes in backpressure.", - ) - flagset.IntVar( - &congestionWindowMax, "bp-max-window", 0, - "Max concurrent queries to passthrough regardless of backpressure health.", - ) - flagset.StringVar( - &backpressureMonitoringURL, "bp-monitoring-url", "", - "The address on which to read backpressure metrics with PromQL queries.", - ) - flagset.Var( - &bpQueries, "bp-query", - "PromQL that signifies an increase in downstream failure", - ) - flagset.Var( - &bpQueryNames, "bp-query-name", - "Name is an optional human readable field used to emit tagged metrics. "+ - "When unset, operational metrics are omitted. "+ - `When set, read warn_threshold as proxymw_bp_warn_threshold{query_name=""}`, - ) - flagset.Var( - &bpWarnThresholds, "bp-warn", - "Threshold that defines when the system should start backing off", - ) - flagset.Var( - &bpEmergencyThresholds, "bp-emergency", - "Threshold that defines when the system should apply maximum throttling", - ) - flagset.BoolVar( - &enableObserver, "enable-observer", false, - "Collect middleware latency and error metrics", - ) - flagset.StringVar(&proxyPaths, "proxy-paths", "", - "Comma delimited allow list of exact HTTP paths that should be allowed to hit "+ - "the upstream URL without any enforcement.") - flagset.StringVar(&passthroughPaths, "passthrough-paths", "", - "Comma delimited allow list of exact HTTP paths that should be allowed to hit "+ - "the upstream URL without any enforcement.") - if err := flagset.Parse(os.Args[1:]); err != nil { - return Config{}, err - } + cfg := Config{} + flags := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - queries, err := ParseBackpressureQueries( - bpQueries, bpQueryNames, bpWarnThresholds, bpEmergencyThresholds, - ) - if err != nil { + var ( + bpQueries StringSlice + bpQueryNames StringSlice + bpWarnThresholds Float64Slice + bpEmergencyThresholds Float64Slice + proxyPaths string + passthroughPaths string + configFile string + ) + + // Config file + flags.StringVar(&configFile, "config-file", "", "Path to proxy configuration file") + + // Server settings + flags.StringVar(&cfg.InsecureListenAddress, "insecure-listen-address", "", "HTTP proxy server listen address") + flags.StringVar(&cfg.InternalListenAddress, "internal-listen-address", "", "Internal metrics server listen address") + flags.DurationVar(&cfg.ReadTimeout, "proxy-read-timeout", 5*time.Minute, "HTTP read timeout") + flags.DurationVar(&cfg.WriteTimeout, "proxy-write-timeout", 5*time.Minute, "HTTP write timeout") + flags.StringVar(&cfg.Upstream, "upstream", "", "Upstream URL to proxy to") + + // Feature flags + flags.BoolVar(&cfg.ProxyConfig.EnableCriticality, "enable-criticality", false, "Enable criticality header processing") + flags.BoolVar(&cfg.ProxyConfig.EnableJitter, "enable-jitter", false, "Enable request jitter") + flags.DurationVar(&cfg.ProxyConfig.JitterDelay, "jitter-delay", 0, "Random jitter delay duration") + flags.BoolVar(&cfg.ProxyConfig.EnableObserver, "enable-observer", false, "Enable middleware metrics collection") + + // Backpressure settings + bp := &cfg.ProxyConfig.BackpressureConfig + flags.BoolVar(&bp.EnableBackpressure, "enable-bp", false, "Enable backpressure-based throttling") + flags.IntVar(&bp.CongestionWindowMin, "bp-min-window", 0, "Minimum concurrent query limit") + flags.IntVar(&bp.CongestionWindowMax, "bp-max-window", 0, "Maximum concurrent query limit") + flags.StringVar(&bp.BackpressureMonitoringURL, "bp-monitoring-url", "", "Backpressure metrics endpoint") + flags.Var(&bpQueries, "bp-query", "PromQL query for downstream failures") + flags.Var(&bpQueryNames, "bp-query-name", "Human-readable name for backpressure query") + flags.Var(&bpWarnThresholds, "bp-warn", "Warning threshold for throttling") + flags.Var(&bpEmergencyThresholds, "bp-emergency", "Emergency threshold for maximum throttling") + + // Path settings + flags.StringVar(&proxyPaths, "proxy-paths", "", "Comma-separated list of paths to proxy") + flags.StringVar(&passthroughPaths, "passthrough-paths", "", "Comma-separated list of paths to pass through") + + if err := flags.Parse(os.Args[1:]); err != nil { return Config{}, err } @@ -169,74 +109,43 @@ func ParseConfigFlags() (Config, error) { return ParseConfigFile(configFile) } - proxyPathsList, err := parsePaths(proxyPaths) - if err != nil { + var err error + if bp.BackpressureQueries, err = proxymw.ParseBackpressureQueries( + bpQueries, bpQueryNames, bpWarnThresholds, bpEmergencyThresholds, + ); err != nil { return Config{}, err } - - passthroughPathsList, err := parsePaths(passthroughPaths) - if err != nil { + if cfg.ProxyPaths, err = parsePaths(proxyPaths); err != nil { + return Config{}, err + } + if cfg.PassthroughPaths, err = parsePaths(passthroughPaths); err != nil { return Config{}, err } - return Config{ - InsecureListenAddress: insecureListenAddress, - InternalListenAddress: internalListenAddress, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - Upstream: upstream, - ProxyPaths: proxyPathsList, - PassthroughPaths: passthroughPathsList, - ProxyConfig: proxymw.Config{ - EnableCriticality: enableCriticality, - EnableJitter: enableJitter, - JitterDelay: jitterDelay, - EnableObserver: enableObserver, - BackpressureConfig: proxymw.BackpressureConfig{ - EnableBackpressure: enableBackpressure, - BackpressureMonitoringURL: backpressureMonitoringURL, - BackpressureQueries: queries, - CongestionWindowMin: congestionWindowMin, - CongestionWindowMax: congestionWindowMax, - }, - }, - }, nil + return cfg, nil } func ParseConfigEnvironment() (Config, error) { - u := os.Getenv("UPSTREAM") - proxyPaths := os.Getenv("PROXY_PATHS") - passthroughPaths := os.Getenv("PASSTHROUGH_PATHS") + cfg := Config{} + var err error - proxyPathsList, err := parsePaths(proxyPaths) - if err != nil { + cfg.Upstream = os.Getenv("UPSTREAM") + + if cfg.ProxyPaths, err = parsePaths(os.Getenv("PROXY_PATHS")); err != nil { return Config{}, err } - - passthroughPathsList, err := parsePaths(passthroughPaths) - if err != nil { + if cfg.PassthroughPaths, err = parsePaths(os.Getenv("PASSTHROUGH_PATHS")); err != nil { return Config{}, err } - enableJitter, err := getBoolEnv("PROXYMW_ENABLE_JITTER") - if err != nil { + if cfg.ProxyConfig.EnableJitter, err = getBoolEnv("PROXYMW_ENABLE_JITTER"); err != nil { return Config{}, err } - - jitterDelay, err := getDurationEnv("PROXYMW_JITTER_DELAY") - if err != nil { + if cfg.ProxyConfig.JitterDelay, err = getDurationEnv("PROXYMW_JITTER_DELAY"); err != nil { return Config{}, err } - return Config{ - Upstream: u, - ProxyPaths: proxyPathsList, - PassthroughPaths: passthroughPathsList, - ProxyConfig: proxymw.Config{ - EnableJitter: enableJitter, - JitterDelay: jitterDelay, - }, - }, nil + return cfg, nil } func getBoolEnv(key string) (bool, error) { @@ -260,62 +169,17 @@ func parsePaths(paths string) ([]string, error) { return []string{}, nil } - pathList := strings.Split(paths, ",") - for _, path := range pathList { - u, err := url.Parse(fmt.Sprintf("http://example.com%v", path)) - if err != nil { - return nil, fmt.Errorf( - "path %q is not a valid URI path, got %v", path, paths, - ) - } - if u.Path != path { - return nil, fmt.Errorf( - "path %q is not a valid URI path, got %v", path, paths, - ) - } - if u.Path == "" || u.Path == "/" { - return nil, fmt.Errorf( - "path %q is not allowed, got %v", u.Path, paths, - ) + pathList := []string{} + for _, path := range strings.Split(paths, ",") { + u, err := url.Parse("http://example.com" + path) + if err != nil || u.Path != path || path == "" || path == "/" { + return nil, fmt.Errorf("invalid path %q in path list %q", path, paths) } + pathList = append(pathList, path) } return pathList, nil } -func ParseBackpressureQueries( - bpQueries, bpQueryNames []string, bpWarnThresholds, bpEmergencyThresholds []float64, -) ([]proxymw.BackpressureQuery, error) { - n := len(bpQueries) - queries := make([]proxymw.BackpressureQuery, n) - if len(bpQueryNames) != n && len(bpQueryNames) != 0 { - return nil, fmt.Errorf("number of backpressure query names should be 0 or %d", n) - } - - if len(bpWarnThresholds) != n { - return nil, fmt.Errorf("expected %d warn thresholds for %d backpressure queries", n, n) - } - - if len(bpEmergencyThresholds) != n { - return nil, fmt.Errorf( - "expected %d emergency thresholds for %d backpressure queries", n, n, - ) - } - - for i, query := range bpQueries { - queryName := "" - if len(bpQueryNames) > 0 { - queryName = bpQueryNames[i] - } - queries[i] = proxymw.BackpressureQuery{ - Name: queryName, - Query: query, - WarningThreshold: bpWarnThresholds[i], - EmergencyThreshold: bpEmergencyThresholds[i], - } - } - return queries, nil -} - func ParseConfigFile(configFile string) (Config, error) { return ParseFile[Config](configFile) } @@ -325,15 +189,14 @@ func ParseProxyConfigFile(configFile string) (proxymw.Config, error) { } func ParseFile[T any](configFile string) (cfg T, err error) { - // nolint:gosec // accept configuration file as input - file, err := os.Open(configFile) + file, err := os.Open(configFile) // nolint:gosec // input configuration file if err != nil { - return cfg, fmt.Errorf("error opening config file: %v", err) + return cfg, fmt.Errorf("error opening config file: %w", err) } defer file.Close() if err := yaml.NewDecoder(file).Decode(&cfg); err != nil { - return cfg, fmt.Errorf("error decoding YAML: %v", err) + return cfg, fmt.Errorf("error decoding YAML: %w", err) } return cfg, nil diff --git a/proxyutil/flag_test.go b/proxyutil/flag_test.go index 9170ce2..fec72e0 100644 --- a/proxyutil/flag_test.go +++ b/proxyutil/flag_test.go @@ -29,8 +29,8 @@ func TestParseConfig(t *testing.T) { cfg: proxyutil.Config{ Upstream: "http://example.com", InsecureListenAddress: ":8080", - ReadTimeout: (time.Minute * 5).String(), - WriteTimeout: (time.Minute * 5).String(), + ReadTimeout: time.Minute * 5, + WriteTimeout: time.Minute * 5, ProxyPaths: []string{}, PassthroughPaths: []string{}, ProxyConfig: proxymw.Config{ @@ -52,8 +52,8 @@ func TestParseConfig(t *testing.T) { "--internal-listen-address", ":9090", "--proxy-paths", "/api/v2", "--passthrough-paths", "/health,/metrics", - "--proxy-read-timeout", "2m", - "--proxy-write-timeout", "3m", + "--proxy-read-timeout", "2m0s", + "--proxy-write-timeout", "3m0s", "--enable-observer=true", "--enable-criticality=true", "--enable-jitter", @@ -79,8 +79,8 @@ func TestParseConfig(t *testing.T) { PassthroughPaths: []string{"/health", "/metrics"}, InsecureListenAddress: ":8080", InternalListenAddress: ":9090", - ReadTimeout: "2m", - WriteTimeout: "3m", + ReadTimeout: 2 * time.Minute, + WriteTimeout: 3 * time.Minute, ProxyConfig: proxymw.Config{ EnableCriticality: true, EnableJitter: true, @@ -193,8 +193,8 @@ func TestParseConfig(t *testing.T) { PassthroughPaths: []string{"/api/v2"}, InsecureListenAddress: "0.0.0.0:7777", InternalListenAddress: "0.0.0.0:7776", - ReadTimeout: "5s", - WriteTimeout: "5s", + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, ProxyConfig: proxymw.Config{ EnableJitter: true, JitterDelay: time.Second * 5,