diff --git a/benchmarks_test.go b/benchmarks_test.go index c06b897..77241c8 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -121,7 +121,7 @@ func TestBenchmark_Cleanup(t *testing.T) { MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 50).Microseconds()}, res) } { - res := bench.Stats(time.Minute) + res := bench.Stats(time.Minute - time.Second) t.Logf("%+v", res) assert.Equal(t, BenchmarkStats{Requests: 60, RequestsSec: 1, AverageRespTime: 50000, MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 50).Microseconds()}, res) diff --git a/onlyfrom.go b/onlyfrom.go index eb24bef..61fbf06 100644 --- a/onlyfrom.go +++ b/onlyfrom.go @@ -5,18 +5,23 @@ import ( "net" "net/http" "strings" + + "github.com/go-pkgz/rest/realip" ) // OnlyFrom middleware allows access for limited list of source IPs. // Such IPs can be defined as complete ip (like 192.168.1.12), prefix (129.168.) or CIDR (192.168.0.0/16) func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - + if len(onlyIps) == 0 { + // no restrictions if no ips defined + h.ServeHTTP(w, r) + return + } matched, ip := matchSourceIP(r, onlyIps) if matched { + // matched ip - allow h.ServeHTTP(w, r) return } @@ -30,19 +35,10 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler { // matchSourceIP returns true if request's ip matches any of ips func matchSourceIP(r *http.Request, ips []string) (result bool, match string) { - - // try X-Real-IP first then fail back to X-Forwarded-For and finally to RemoteAddr - ip := r.Header.Get("X-Real-IP") - if ip == "" { - ip = strings.Split(r.Header.Get("X-Forwarded-For"), ", ")[0] - } - if ip == "" { - ip = r.Header.Get("RemoteAddr") + ip, err := realip.Get(r) + if err != nil { + return false, "" // we can't get ip, so no match } - if ip == "" { - ip = strings.Split(r.RemoteAddr, ":")[0] - } - // check for ip prefix or CIDR for _, exclIP := range ips { if _, cidrnet, err := net.ParseCIDR(exclIP); err == nil { diff --git a/onlyfrom_test.go b/onlyfrom_test.go index 279e5ba..d83612f 100644 --- a/onlyfrom_test.go +++ b/onlyfrom_test.go @@ -10,8 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestOnlyFromAllowed(t *testing.T) { - +func TestOnlyFromAllowedIP(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("blah blah")) require.NoError(t, err) @@ -30,7 +29,6 @@ func TestOnlyFromAllowed(t *testing.T) { } func TestOnlyFromAllowedHeaders(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("blah blah")) require.NoError(t, err) @@ -48,26 +46,32 @@ func TestOnlyFromAllowedHeaders(t *testing.T) { } client := http.Client{} - req, err := reqWithHeader("X-Real-IP") - require.NoError(t, err) - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 200, resp.StatusCode) + t.Run("X-Real-IP", func(t *testing.T) { + req, err := reqWithHeader("X-Real-IP") + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + }) - req, err = reqWithHeader("X-Forwarded-For") - require.NoError(t, err) - resp, err = client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 200, resp.StatusCode) + t.Run("X-Forwarded-For", func(t *testing.T) { + req, err := reqWithHeader("X-Forwarded-For") + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + }) - req, err = reqWithHeader("RemoteAddr") - require.NoError(t, err) - resp, err = client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 200, resp.StatusCode) + t.Run("X-Forwarded-For and X-Real-IP missing", func(t *testing.T) { + req, err := reqWithHeader("blah") + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 403, resp.StatusCode) + }) } func TestOnlyFromAllowedCIDR(t *testing.T) { diff --git a/realip/real.go b/realip/real.go index c30e3d4..58835c7 100644 --- a/realip/real.go +++ b/realip/real.go @@ -25,7 +25,7 @@ var privateRanges = []ipRange{ // Get returns real ip from the given request func Get(r *http.Request) (string, error) { - + var firstIP string for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { addresses := strings.Split(r.Header.Get(h), ",") // march from right to left until we get a public address @@ -33,6 +33,9 @@ func Get(r *http.Request) (string, error) { for i := len(addresses) - 1; i >= 0; i-- { ip := strings.TrimSpace(addresses[i]) realIP := net.ParseIP(ip) + if firstIP == "" && realIP != nil { + firstIP = ip + } if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) { continue } @@ -40,9 +43,9 @@ func Get(r *http.Request) (string, error) { } } - // X-Forwarded-For header set but parsing failed above - if r.Header.Get("X-Forwarded-For") != "" { - return "", fmt.Errorf("no valid ip found") + // if we cannot find a public address in X-Forwarded-For or X-Real-IP headers, fallback to first ip + if firstIP != "" { + return firstIP, nil } // get IP from RemoteAddr diff --git a/realip/real_test.go b/realip/real_test.go index 8a56502..1d76228 100644 --- a/realip/real_test.go +++ b/realip/real_test.go @@ -12,7 +12,7 @@ import ( ) func TestGetFromHeaders(t *testing.T) { - { + t.Run("single X-Real-IP", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) req.Header.Add("Something", "1234567") @@ -20,8 +20,8 @@ func TestGetFromHeaders(t *testing.T) { adr, err := Get(req) require.NoError(t, err) assert.Equal(t, "8.8.8.8", adr) - } - { + }) + t.Run("X-Forwarded-For last public", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) req.Header.Add("Something", "1234567") @@ -29,8 +29,8 @@ func TestGetFromHeaders(t *testing.T) { adr, err := Get(req) require.NoError(t, err) assert.Equal(t, "30.30.30.1", adr) - } - { + }) + t.Run("X-Forwarded-For last private", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) req.Header.Add("Something", "1234567") @@ -38,8 +38,17 @@ func TestGetFromHeaders(t *testing.T) { adr, err := Get(req) require.NoError(t, err) assert.Equal(t, "1.1.1.2", adr) - } - { + }) + t.Run("X-Forwarded-For all private", func(t *testing.T) { + req, err := http.NewRequest("GET", "/something", http.NoBody) + assert.NoError(t, err) + req.Header.Add("Something", "1234567") + req.Header.Add("X-Forwarded-For", "192.168.1.1,10.0.0.65") + adr, err := Get(req) + require.NoError(t, err) + assert.Equal(t, "10.0.0.65", adr) + }) + t.Run("X-Forwarded-For public, X-Real-IP private", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) req.Header.Add("Something", "1234567") @@ -48,8 +57,8 @@ func TestGetFromHeaders(t *testing.T) { adr, err := Get(req) require.NoError(t, err) assert.Equal(t, "30.30.30.1", adr) - } - { + }) + t.Run("X-Forwarded-For and X-Real-IP public", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) req.Header.Add("Something", "1234567") @@ -58,8 +67,8 @@ func TestGetFromHeaders(t *testing.T) { adr, err := Get(req) require.NoError(t, err) assert.Equal(t, "30.30.30.1", adr) - } - { + }) + t.Run("X-Forwarded-For private and X-Real-IP public]", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) req.Header.Add("Something", "1234567") @@ -68,14 +77,22 @@ func TestGetFromHeaders(t *testing.T) { adr, err := Get(req) require.NoError(t, err) assert.Equal(t, "8.8.8.8", adr) - } - { + }) + t.Run("RemoteAddr fallback", func(t *testing.T) { + req, err := http.NewRequest("GET", "/something", http.NoBody) + assert.NoError(t, err) + req.RemoteAddr = "192.0.2.1:1234" + adr, err := Get(req) + require.NoError(t, err) + assert.Equal(t, "192.0.2.1", adr) + }) + t.Run("X-Forwarded-For and X-Real-IP missing, no RemoteAddr either", func(t *testing.T) { req, err := http.NewRequest("GET", "/something", http.NoBody) assert.NoError(t, err) ip, err := Get(req) assert.Error(t, err) assert.Equal(t, "", ip) - } + }) } func TestGetFromRemoteAddr(t *testing.T) {