diff --git a/middleware/client_ip.go b/middleware/client_ip.go new file mode 100644 index 00000000..d8b76c3d --- /dev/null +++ b/middleware/client_ip.go @@ -0,0 +1,185 @@ +package middleware + +import ( + "context" + "net" + "net/http" + "net/netip" + "strings" +) + +var ( + // clientIPCtxKey is the context key used to store the client IP address. + clientIPCtxKey = &contextKey{"clientIP"} +) + +// ClientIPFromHeader parses the client IP address from a specified HTTP header +// (e.g., X-Real-IP, CF-Connecting-IP) and injects it into the request context +// if it is not already set. The parsed IP address can be retrieved using GetClientIP(). +// +// The middleware validates the IP address to ignore loopback, private, and unspecified addresses. +// +// ### Important Notice: +// - Use this middleware only when your infrastructure sets a trusted header containing the client IP. +// - If the specified header is not securely set by your infrastructure, malicious clients could spoof it. +// +// Example trusted headers: +// - "X-Real-IP" - Nginx (ngx_http_realip_module) +// - "X-Client-IP" - Apache (mod_remoteip) +// - "CF-Connecting-IP" - Cloudflare +// - "True-Client-IP" - Akamai, Cloudflare Enterprise +// - "X-Azure-ClientIP" - Azure Front Door +// - "Fastly-Client-IP" - Fastly +func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Check if the client IP is already set in the context. + if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok { + h.ServeHTTP(w, r) + return + } + + // Parse the IP address from the trusted header. + ip, err := netip.ParseAddr(r.Header.Get(trustedHeader)) + if err != nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsPrivate() { + // Ignore invalid or private IPs. + h.ServeHTTP(w, r) + return + } + + // Store the valid client IP in the context. + ctx = context.WithValue(ctx, clientIPCtxKey, ip) + h.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} + +// ClientIPFromXFFHeader parses the client IP address from the X-Forwarded-For +// header and injects it into the request context if it is not already set. The +// parsed IP address can be retrieved using GetClientIP(). +// +// The middleware traverses the X-Forwarded-For chain (rightmost untrusted IP) +// and excludes loopback, private, unspecified, and trusted IP ranges. +// +// ### Important Notice: +// - Use this middleware only when your infrastructure sets and validates the X-Forwarded-For header. +// - Malicious clients can spoof the header unless a trusted reverse proxy or load balancer sanitizes it. +// +// Parameters: +// - `trustedIPPrefixes`: A list of CIDR prefixes that define trusted proxy IP ranges. +// +// Example trusted IP ranges: +// - "203.0.113.0/24" - Example corporate proxy +// - "198.51.100.0/24" - Example data center or hosting provider +// - "2400:cb00::/32" - Cloudflare IPv6 range +// - "2606:4700::/32" - Cloudflare IPv6 range +// - "192.0.2.0/24" - Example VPN gateway +// +// Note: Private IP ranges (e.g., "10.0.0.0/8", "192.168.0.0/16", "172.16.0.0/12") +// are automatically excluded by netip.Addr.IsPrivate() and do not need to be added here. +func ClientIPFromXFFHeader(trustedIPPrefixes ...string) func(http.Handler) http.Handler { + // Pre-parse trusted prefixes. + trustedPrefixes := make([]netip.Prefix, len(trustedIPPrefixes)) + for i, ipRange := range trustedIPPrefixes { + trustedPrefixes[i] = netip.MustParsePrefix(ipRange) + } + + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Check if the client IP is already set in the context. + if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok { + h.ServeHTTP(w, r) + return + } + + // Parse and split the X-Forwarded-For header(s). + xff := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",") + nextValue: + for i := len(xff) - 1; i >= 0; i-- { + ip, err := netip.ParseAddr(strings.TrimSpace(xff[i])) + if err != nil { + continue + } + + // Ignore loopback, private, or unspecified addresses. + if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() { + continue + } + + // Ignore trusted IPs within the given ranges. + for _, prefix := range trustedPrefixes { + if prefix.Contains(ip) { + continue nextValue + } + } + + // Store the valid client IP in the context. + ctx = context.WithValue(ctx, clientIPCtxKey, ip) + h.ServeHTTP(w, r.WithContext(ctx)) + return + } + + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + +// ClientIPFromRemoteAddr extracts the client IP address from the RemoteAddr +// field of the HTTP request and injects it into the request context if it is +// not already set. The parsed IP address can be retrieved using GetClientIP(). +// +// The middleware ignores invalid or private IPs. +// +// ### Use Case: +// This middleware is useful when the client IP cannot be determined from headers +// such as X-Forwarded-For or X-Real-IP, and you need to fall back to RemoteAddr. +func ClientIPFromRemoteAddr(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Check if the client IP is already set in the context. + if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok { + h.ServeHTTP(w, r) + return + } + + // Extract the IP from request RemoteAddr. + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + h.ServeHTTP(w, r) + return + } + + ip, err := netip.ParseAddr(host) + if err != nil { + h.ServeHTTP(w, r) + return + } + + // Store the valid client IP in the context. + ctx = context.WithValue(ctx, clientIPCtxKey, ip) + h.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// GetClientIP retrieves the client IP address from the given context. +// The IP address is set by one of the following middlewares: +// - ClientIPFromHeader +// - ClientIPFromXFFHeader +// - ClientIPFromRemoteAddr +// +// Returns an empty string if no valid IP is found. +func GetClientIP(ctx context.Context) string { + ip, ok := ctx.Value(clientIPCtxKey).(netip.Addr) + if !ok || !ip.IsValid() { + return "" + } + return ip.String() +} diff --git a/middleware/client_ip_test.go b/middleware/client_ip_test.go new file mode 100644 index 00000000..2c377eb8 --- /dev/null +++ b/middleware/client_ip_test.go @@ -0,0 +1,141 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestClientIPFromHeader(t *testing.T) { + tt := []struct { + name string + in string + out string + }{ + // Empty header. + {name: "empty", in: "", out: ""}, + + // Valid X-Real-IP header values. + {name: "valid_ipv4", in: "100.100.100.100", out: "100.100.100.100"}, + {name: "valid_ipv4", in: "178.25.203.2", out: "178.25.203.2"}, + {name: "valid_ipv6_lower", in: "2345:0425:2ca1:0000:0000:0567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"}, + {name: "valid_ipv6_upper", in: "2345:0425:2CA1:0000:0000:0567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"}, + {name: "valid_ipv6_lower_short", in: "2345:425:2ca1::567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"}, + {name: "valid_ipv6_upper_short", in: "2345:425:2CA1::567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"}, + + // Invalid X-Real-IP header values. + {name: "invalid_ip", in: "invalid", out: ""}, + {name: "invalid_ip_with_port", in: "100.100.100.100:80", out: ""}, + {name: "invalid_multiple_ips", in: "100.100.100.100;100.100.100.101;100.100.100.102", out: ""}, + {name: "invalid_loopback", in: "127.0.0.1", out: ""}, + {name: "invalid_zeroes", in: "0.0.0.0", out: ""}, + {name: "invalid_loopback", in: "127.0.0.1", out: ""}, + {name: "invalid_private_ipv4_1", in: "192.168.0.1", out: ""}, + {name: "invalid_private_ipv4_2", in: "192.168.10.12", out: ""}, + {name: "invalid_private_ipv4_3", in: "172.16.0.0", out: ""}, + {name: "invalid_private_ipv4_4", in: "172.25.203.2", out: ""}, + {name: "invalid_private_ipv4_5", in: "10.0.0.0", out: ""}, + {name: "invalid_private_ipv4_6", in: "10.0.1.10", out: ""}, + {name: "invalid_private_ipv6_1", in: "fc00::1", out: ""}, + {name: "invalid_private_ipv6_2", in: "fc00:0425:2ca1:0000:0000:0567:5673:23b5", out: ""}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add("X-Real-IP", tc.in) + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(ClientIPFromHeader("X-Real-IP")) + + var clientIP string + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + clientIP = GetClientIP(r.Context()) + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Errorf("Response Code should be 200") + } + + if clientIP != tc.out { + t.Errorf("expected %v, got %v", tc.out, clientIP) + } + }) + } +} + +func TestClientIPFromXFFHeader(t *testing.T) { + tt := []struct { + name string + xff []string + out string + }{ + {name: "empty", xff: []string{""}, out: ""}, + + {name: "", xff: []string{"100.100.100.100"}, out: "100.100.100.100"}, + {name: "", xff: []string{"100.100.100.100, 200.200.200.200"}, out: "200.200.200.200"}, + {name: "", xff: []string{"100.100.100.100,200.200.200.200"}, out: "200.200.200.200"}, + {name: "", xff: []string{"100.100.100.100", "200.200.200.200"}, out: "200.200.200.200"}, + {name: "", xff: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"}, + {name: "", xff: []string{"203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"}, + {name: "", xff: []string{"5.5.5.5, 203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348", "7.7.7.7, 4.4.4.4"}, out: "4.4.4.4"}, + } + + r := chi.NewRouter() + r.Use(ClientIPFromXFFHeader()) + + for _, tc := range tt { + req, _ := http.NewRequest("GET", "/", nil) + for _, v := range tc.xff { + req.Header.Add("X-Forwarded-For", v) + } + + w := httptest.NewRecorder() + + clientIP := "" + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + clientIP = GetClientIP(r.Context()) + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Errorf("Response Code should be 200") + } + + if clientIP != tc.out { + t.Errorf("expected %v, got %v", tc.out, clientIP) + } + } +} + +func TestClientIPFromRemoteAddr(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "192.0.2.1:1234" // Simulate the remote address set by http.Server. + + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(ClientIPFromRemoteAddr) + + var clientIP string + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + clientIP = GetClientIP(r.Context()) + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Errorf("Response Code should be 200") + } + + expected := "192.0.2.1" + if clientIP != expected { + t.Errorf("expected %v, got %v", expected, clientIP) + } +}