From f6dfcbe774b43b9ad12e9b477bb34b4412c23452 Mon Sep 17 00:00:00 2001 From: Arun Gopalpuri Date: Thu, 3 Sep 2020 00:39:57 -0700 Subject: [PATCH] bugfix proxy and rewrite, updated test with actual call settings --- middleware/middleware.go | 31 ++++++++++++++++++++++--------- middleware/proxy.go | 22 +++++----------------- middleware/proxy_test.go | 15 +++++++-------- middleware/rewrite.go | 30 ++++-------------------------- middleware/rewrite_test.go | 25 ++++++++++--------------- 5 files changed, 48 insertions(+), 75 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 12260ddb2..60834b505 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "net/url" "regexp" "strconv" "strings" @@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } -//rewritePath sets request url path and raw path -func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error { - replacerRawPath := replacer.Replace(target) - replacerPath, err := url.PathUnescape(replacerRawPath) - if err != nil { - return err +func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { + // Initialize + rulesRegex := map[*regexp.Regexp]string{} + for k, v := range rewrite { + k = regexp.QuoteMeta(k) + k = strings.Replace(k, `\*`, "(.*)", -1) + if strings.HasPrefix(k, `\^`) { + k = strings.Replace(k, `\^`, "^", -1) + } + k = k + "$" + rulesRegex[regexp.MustCompile(k)] = v + } + return rulesRegex +} + +func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { + for k, v := range rewriteRegex { + replacerRawPath := captureTokens(k, req.URL.EscapedPath()) + if replacerRawPath != nil { + replacerPath := captureTokens(k, req.URL.Path) + req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v) + } } - req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath - return nil } // DefaultSkipper returns false which processes the middleware. diff --git a/middleware/proxy.go b/middleware/proxy.go index cd50b76a1..1b972eb16 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "regexp" - "strings" "sync" "sync/atomic" "time" @@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rewrite { - k = strings.Replace(k, "*", "(\\S*)", -1) - config.rewriteRegex[regexp.MustCompile(k)] = v - } + config.rewriteRegex = rewriteRulesRegex(config.Rewrite) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Rewrite - for k, v := range config.rewriteRegex { - //use req.URL.Path here or else we will have double escaping - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - if err := rewritePath(replacer, v, req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid url") - } - } - } + // Set rewrite path and raw path + rewritePath(config.rewriteRegex, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. @@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } + + diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 4bb74648c..534e45f44 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -94,36 +94,35 @@ func TestProxy(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/js/main.js" + req.URL, _ = url.Parse( "/js/main.js") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/old" + req.URL, _ = url.Parse("/old") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jack/orders/1" + req.URL, _ = url.Parse( "/users/jack/orders/1") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/%%%%" + req.URL, _ = url.Parse("/api/new users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) - + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) // ModifyResponse e = echo.New() e.Use(ProxyWithConfig(ProxyConfig{ diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 855c8633a..0965e313f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,11 +1,8 @@ package middleware import ( - "net/http" - "regexp" - "strings" - "github.com/labstack/echo/v4" + "regexp" ) type ( @@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rules { - k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*)", -1) - if strings.HasPrefix(k, `\^`) { - k = strings.Replace(k, `\^`, "^", -1) - } - k = k + "$" - config.rulesRegex[regexp.MustCompile(k)] = v - } + config.rulesRegex = rewriteRulesRegex(config.Rules) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } req := c.Request() - // Rewrite - for k, v := range config.rulesRegex { - //use req.URL.Path here or else we will have double escaping - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - if err := rewritePath(replacer, v, req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid url") - } - break - } - } + // Set rewrite path and raw path + rewritePath(config.rulesRegex, req) return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index a9b3437ce..abf11b2f7 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/labstack/echo/v4" @@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) { })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") e.ServeHTTP(rec, req) assert.Equal(t, "/users", req.URL.EscapedPath()) - req.URL.Path = "/js/main.js" + req.URL, _ = url.Parse("/js/main.js") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - req.URL.Path = "/old" + req.URL, _ = url.Parse("/old") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) - req.URL.Path = "/users/jack/orders/1" + req.URL, _ = url.Parse("/users/jack/orders/1") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - req.URL.Path = "/api/new users" - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/%%%%" - rec = httptest.NewRecorder() + req.URL, _ = url.Parse("/api/new users") e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) } // Issue #1086 @@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) { r := e.Router() // Rewrite old url to new one - e.Pre(RewriteWithConfig(RewriteConfig{ - Rules: map[string]string{ + e.Pre(Rewrite(map[string]string{ "/old": "/new", }, - })) + )) // Route r.Add(http.MethodGet, "/new", func(c echo.Context) error {