From 88dde29118a1f8d277cc7d7a4868f032addf33e7 Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 20 Jan 2024 15:08:56 +0000 Subject: [PATCH 1/2] Redirect all-caps article slugs to lowercase. Move this redirect functionality from nginx into Router, so that we can finally drop Perl from the nginx config. We'll need to push this further down into the backend ("frontend") apps when we get rid of Router, but there's still value in eliminating Perl from the request path sooner rather than later. --- handlers/redirect_handler.go | 24 ++++++++++ integration_tests/redirect_test.go | 45 ++++++++++++++++++ triemux/mux.go | 30 ++++++++++-- triemux/mux_test.go | 74 +++++++++++++++++------------- 4 files changed, 136 insertions(+), 37 deletions(-) diff --git a/handlers/redirect_handler.go b/handlers/redirect_handler.go index 8fcdd7c7..932bd3d0 100644 --- a/handlers/redirect_handler.go +++ b/handlers/redirect_handler.go @@ -17,6 +17,7 @@ const ( redirectHandlerType = "redirect-handler" pathPreservingRedirectHandlerType = "path-preserving-redirect-handler" + downcaseRedirectHandlerType = "downcase-redirect-handler" ) func NewRedirectHandler(source, target string, preserve bool, temporary bool) http.Handler { @@ -87,3 +88,26 @@ func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWrit "redirect_type": pathPreservingRedirectHandlerType, }).Inc() } + +type downcaseRedirectHandler struct{} + +func NewDowncaseRedirectHandler() http.Handler { + return &downcaseRedirectHandler{} +} + +func (handler *downcaseRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + const status = http.StatusMovedPermanently + + target := strings.ToLower(r.URL.Path) + if r.URL.RawQuery != "" { + target += "?" + r.URL.RawQuery + } + + addCacheHeaders(w) + http.Redirect(w, r, target, status) + + redirectCountMetric.With(prometheus.Labels{ + "redirect_code": fmt.Sprintf("%d", status), + "redirect_type": downcaseRedirectHandlerType, + }).Inc() +} diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 91474c21..9bd20637 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -5,6 +5,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/ghttp" ) var _ = Describe("Redirection", func() { @@ -227,4 +228,48 @@ var _ = Describe("Redirection", func() { Expect(resp.Header.Get("Location")).To(Equal("https://param.servicegov.uk?_ga=12345&included-param=true")) }) }) + + Describe("path case normalisation rule", func() { + var recorder *ghttp.Server + + BeforeEach(func() { + recorder = startRecordingBackend() + addBackend("be", recorder.URL()) + addRoute("/guidance/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) + addRoute("/GUIDANCE/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) + reloadRoutes(apiPort) + }) + + AfterEach(func() { + recorder.Close() + }) + + It("should permanently redirect an ALL CAPS path to lowercase", func() { + resp := routerRequest(routerPort, "/GUIDANCE/KEEPING-A-PET-PIG-OR-MICROPIG") + Expect(resp.StatusCode).To(Equal(301)) + Expect(resp.Header.Get("Location")).To(Equal("/guidance/keeping-a-pet-pig-or-micropig")) + }) + + It("should preserve case in the query string", func() { + resp := routerRequest(routerPort, "/GUIDANCE/KEEPING-A-PET-PIG-OR-MICROPIG?Pig=Kunekune") + Expect(resp.StatusCode).To(Equal(301)) + Expect(resp.Header.Get("Location")).To(Equal("/guidance/keeping-a-pet-pig-or-micropig?Pig=Kunekune")) + }) + + It("should forward an all-lowercase path unchanged", func() { + resp := routerRequest(routerPort, "/guidance/keeping-a-pet-pig-or-micropig") + Expect(resp.StatusCode).To(Equal(200)) + Expect(recorder.ReceivedRequests()).To(HaveLen(1)) + beReq := recorder.ReceivedRequests()[0] + Expect(beReq.URL.RequestURI()).To(Equal("/guidance/keeping-a-pet-pig-or-micropig")) + }) + + It("should forward a mixed-case path unchanged", func() { + resp := routerRequest(routerPort, "/GUIDANCE/keeping-a-pet-pig-or-micropig") + Expect(resp.StatusCode).To(Equal(200)) + Expect(recorder.ReceivedRequests()).To(HaveLen(1)) + beReq := recorder.ReceivedRequests()[0] + Expect(beReq.URL.RequestURI()).To(Equal("/GUIDANCE/keeping-a-pet-pig-or-micropig")) + }) + }) }) diff --git a/triemux/mux.go b/triemux/mux.go index a669c9b0..57e09d18 100644 --- a/triemux/mux.go +++ b/triemux/mux.go @@ -5,9 +5,11 @@ package triemux import ( "net/http" + "regexp" "strings" "sync" + "github.com/alphagov/router/handlers" "github.com/alphagov/router/logger" "github.com/alphagov/router/trie" ) @@ -17,20 +19,22 @@ type Mux struct { exactTrie *trie.Trie[http.Handler] prefixTrie *trie.Trie[http.Handler] count int + downcaser http.Handler } // NewMux makes a new empty Mux. func NewMux() *Mux { return &Mux{ - exactTrie: trie.NewTrie[http.Handler](), + exactTrie: trie.NewTrie[http.Handler](), prefixTrie: trie.NewTrie[http.Handler](), + downcaser: handlers.NewDowncaseRedirectHandler(), } } -// ServeHTTP dispatches the request to a backend with a registered route -// matching the request path, or 404s. -// -// If the routing table is empty, return a 503. +// ServeHTTP forwards the request to a backend with a registered route matching +// the request path. Serves 404 when there is no backend. Serves 301 redirect +// to lowercase path when the URL path is entirely uppercase. Serves 503 when +// no routes are loaded. func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if mux.count == 0 { w.WriteHeader(http.StatusServiceUnavailable) @@ -42,6 +46,11 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if shouldRedirToLowercasePath(r.URL.Path) { + mux.downcaser.ServeHTTP(w, r) + return + } + handler, ok := mux.lookup(r.URL.Path) if !ok { http.NotFound(w, r) @@ -50,6 +59,17 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler.ServeHTTP(w, r) } +// shouldRedirToLowercasePath takes a URL path string (such as "/government/guidance") +// and returns: +// - true, if path is in all caps; for example: +// "/GOVERNMENT/GUIDANCE" -> true (should redirect to "/government/guidance") +// - false, otherwise; for example: +// "/GoVeRnMeNt/gUiDaNcE" -> false (should forward "/GoVeRnMeNt/gUiDaNcE" as-is) +func shouldRedirToLowercasePath(path string) (match bool) { + match, _ = regexp.MatchString(`^\/[A-Z]+[A-Z\W\d]+$`, path) + return +} + // lookup finds a URL path in the Mux and returns the corresponding handler. func (mux *Mux) lookup(path string) (handler http.Handler, ok bool) { mux.mu.RLock() diff --git a/triemux/mux_test.go b/triemux/mux_test.go index 209a1265..8da612b9 100644 --- a/triemux/mux_test.go +++ b/triemux/mux_test.go @@ -10,47 +10,57 @@ import ( promtest "github.com/prometheus/client_golang/prometheus/testutil" ) -type SplitExample struct { - in string - out []string -} - -var splitExamples = []SplitExample{ - {"", []string{}}, - {"/", []string{}}, - {"foo", []string{"foo"}}, - {"/foo", []string{"foo"}}, - {"/füßball", []string{"füßball"}}, - {"/foo/bar", []string{"foo", "bar"}}, - {"///foo/bar", []string{"foo", "bar"}}, - {"foo/bar", []string{"foo", "bar"}}, - {"/foo/bar/", []string{"foo", "bar"}}, - {"/foo//bar/", []string{"foo", "bar"}}, - {"/foo/////bar/", []string{"foo", "bar"}}, -} - func TestSplitPath(t *testing.T) { - for _, ex := range splitExamples { - testSplitPath(t, ex) + tests := []struct { + in string + out []string + }{ + {"", []string{}}, + {"/", []string{}}, + {"foo", []string{"foo"}}, + {"/foo", []string{"foo"}}, + {"/füßball", []string{"füßball"}}, + {"/foo/bar", []string{"foo", "bar"}}, + {"///foo/bar", []string{"foo", "bar"}}, + {"foo/bar", []string{"foo", "bar"}}, + {"/foo/bar/", []string{"foo", "bar"}}, + {"/foo//bar/", []string{"foo", "bar"}}, + {"/foo/////bar/", []string{"foo", "bar"}}, + } + + for _, ex := range tests { + out := splitPath(ex.in) + if len(out) != len(ex.out) { + t.Errorf("splitPath(%v) was not %v", ex.in, ex.out) + } + for i := range ex.out { + if out[i] != ex.out[i] { + t.Errorf("splitPath(%v) differed from %v at component %d "+ + "(expected %v, got %v)", out, ex.out, i, ex.out[i], out[i]) + } + } } } -func testSplitPath(t *testing.T, ex SplitExample) { - out := splitPath(ex.in) - if len(out) != len(ex.out) { - t.Errorf("splitPath(%v) was not %v", ex.in, ex.out) +func TestShouldRedirToLowercasePath(t *testing.T) { + tests := []struct { + in string + out bool + }{ + {"/GOVERNMENT/GUIDANCE", true}, + {"/GoVeRnMeNt/gUiDaNcE", false}, + {"/government/guidance", false}, } - for i := range ex.out { - if out[i] != ex.out[i] { - t.Errorf("splitPath(%v) differed from %v at component %d "+ - "(expected %v, got %v)", out, ex.out, i, ex.out[i], out[i]) + + for _, ex := range tests { + out := shouldRedirToLowercasePath(ex.in) + if out != ex.out { + t.Errorf("shouldRedirToLowercasePath(%v): expected %v, got %v", ex.in, ex.out, out) } } } -type DummyHandler struct { - id string -} +type DummyHandler struct{ id string } func (dh *DummyHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {} From a37389cb4a2c3464e7c476614d0a33eb3f0a04aa Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 20 Jan 2024 15:10:24 +0000 Subject: [PATCH 2/2] Trivial style cleanup in integration tests. - Use the conventional `r` and `w` short names for http.Request and http.ResponseWriter parameters. - Factor out ghttp.Server declaration. - Push recorderURL down into the only test case that uses it. - Fix missing err check on a call to URL.Parse(). No functional change (apart from the missing err check), just readability. --- handlers/redirect_handler.go | 32 ++++++++++++------------ integration_tests/proxy_function_test.go | 29 ++++----------------- 2 files changed, 21 insertions(+), 40 deletions(-) diff --git a/handlers/redirect_handler.go b/handlers/redirect_handler.go index 932bd3d0..125fb60c 100644 --- a/handlers/redirect_handler.go +++ b/handlers/redirect_handler.go @@ -31,16 +31,16 @@ func NewRedirectHandler(source, target string, preserve bool, temporary bool) ht return &redirectHandler{target, statusMoved} } -func addCacheHeaders(writer http.ResponseWriter) { - writer.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123)) - writer.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second)) +func addCacheHeaders(w http.ResponseWriter) { + w.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123)) + w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second)) } -func addGAQueryParam(target string, request *http.Request) string { - if ga := request.URL.Query().Get("_ga"); ga != "" { +func addGAQueryParam(target string, r *http.Request) string { + if ga := r.URL.Query().Get("_ga"); ga != "" { u, err := url.Parse(target) if err != nil { - defer logger.NotifySentry(logger.ReportableError{Error: err, Request: request}) + defer logger.NotifySentry(logger.ReportableError{Error: err, Request: r}) return target } values := u.Query() @@ -56,11 +56,11 @@ type redirectHandler struct { code int } -func (handler *redirectHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - addCacheHeaders(writer) +func (handler *redirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + addCacheHeaders(w) - target := addGAQueryParam(handler.url, request) - http.Redirect(writer, request, target, handler.code) + target := addGAQueryParam(handler.url, r) + http.Redirect(w, r, target, handler.code) redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), @@ -74,14 +74,14 @@ type pathPreservingRedirectHandler struct { code int } -func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - target := handler.targetPrefix + strings.TrimPrefix(request.URL.Path, handler.sourcePrefix) - if request.URL.RawQuery != "" { - target += "?" + request.URL.RawQuery +func (handler *pathPreservingRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + target := handler.targetPrefix + strings.TrimPrefix(r.URL.Path, handler.sourcePrefix) + if r.URL.RawQuery != "" { + target += "?" + r.URL.RawQuery } - addCacheHeaders(writer) - http.Redirect(writer, request, target, handler.code) + addCacheHeaders(w) + http.Redirect(w, r, target, handler.code) redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index d7c93fc5..e9782347 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -15,6 +15,7 @@ import ( ) var _ = Describe("Functioning as a reverse proxy", func() { + var recorder *ghttp.Server Describe("connecting to the backend", func() { It("should return a 502 if the connection to the backend is refused", func() { @@ -74,10 +75,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("response header timeout", func() { - var ( - tarpit1 *httptest.Server - tarpit2 *httptest.Server - ) + var tarpit1, tarpit2 *httptest.Server BeforeEach(func() { err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s"}) @@ -125,14 +123,8 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("header handling", func() { - var ( - recorder *ghttp.Server - recorderURL *url.URL - ) - BeforeEach(func() { recorder = startRecordingBackend() - recorderURL, _ = url.Parse(recorder.URL()) addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) @@ -161,6 +153,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Expect(resp.StatusCode).To(Equal(200)) + recorderURL, err := url.Parse(recorder.URL()) + Expect(err).NotTo(HaveOccurred()) + Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] Expect(beReq.Host).To(Equal(recorderURL.Host)) @@ -252,10 +247,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("request verb, path, query and body handling", func() { - var ( - recorder *ghttp.Server - ) - BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) @@ -312,10 +303,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling a backend with a non '/' path", func() { - var ( - recorder *ghttp.Server - ) - BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()+"/something") @@ -347,10 +334,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling HTTP/1.0 requests", func() { - var ( - recorder *ghttp.Server - ) - BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) @@ -384,8 +367,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling requests to a HTTPS backend", func() { - var recorder *ghttp.Server - BeforeEach(func() { err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"}) Expect(err).NotTo(HaveOccurred())