diff --git a/handlers/redirect_handler.go b/handlers/redirect_handler.go index 8fcdd7c7..125fb60c 100644 --- a/handlers/redirect_handler.go +++ b/handlers/redirect_handler.go @@ -17,6 +17,7 @@ const ( redirectHandlerType = "redirect-handler" pathPreservingRedirectHandlerType = "path-preserving-redirect-handler" + downcaseRedirectHandlerType = "downcase-redirect-handler" ) func NewRedirectHandler(source, target string, preserve bool, temporary bool) http.Handler { @@ -30,16 +31,16 @@ func NewRedirectHandler(source, target string, preserve bool, temporary bool) ht return &redirectHandler{target, statusMoved} } -func addCacheHeaders(writer http.ResponseWriter) { - writer.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123)) - writer.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second)) +func addCacheHeaders(w http.ResponseWriter) { + w.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123)) + w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second)) } -func addGAQueryParam(target string, request *http.Request) string { - if ga := request.URL.Query().Get("_ga"); ga != "" { +func addGAQueryParam(target string, r *http.Request) string { + if ga := r.URL.Query().Get("_ga"); ga != "" { u, err := url.Parse(target) if err != nil { - defer logger.NotifySentry(logger.ReportableError{Error: err, Request: request}) + defer logger.NotifySentry(logger.ReportableError{Error: err, Request: r}) return target } values := u.Query() @@ -55,11 +56,11 @@ type redirectHandler struct { code int } -func (handler *redirectHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - addCacheHeaders(writer) +func (handler *redirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + addCacheHeaders(w) - target := addGAQueryParam(handler.url, request) - http.Redirect(writer, request, target, handler.code) + target := addGAQueryParam(handler.url, r) + http.Redirect(w, r, target, handler.code) redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), @@ -73,17 +74,40 @@ type pathPreservingRedirectHandler struct { code int } -func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - target := handler.targetPrefix + strings.TrimPrefix(request.URL.Path, handler.sourcePrefix) - if request.URL.RawQuery != "" { - target += "?" + request.URL.RawQuery +func (handler *pathPreservingRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + target := handler.targetPrefix + strings.TrimPrefix(r.URL.Path, handler.sourcePrefix) + if r.URL.RawQuery != "" { + target += "?" + r.URL.RawQuery } - addCacheHeaders(writer) - http.Redirect(writer, request, target, handler.code) + addCacheHeaders(w) + http.Redirect(w, r, target, handler.code) redirectCountMetric.With(prometheus.Labels{ "redirect_code": fmt.Sprintf("%d", handler.code), "redirect_type": pathPreservingRedirectHandlerType, }).Inc() } + +type downcaseRedirectHandler struct{} + +func NewDowncaseRedirectHandler() http.Handler { + return &downcaseRedirectHandler{} +} + +func (handler *downcaseRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + const status = http.StatusMovedPermanently + + target := strings.ToLower(r.URL.Path) + if r.URL.RawQuery != "" { + target += "?" + r.URL.RawQuery + } + + addCacheHeaders(w) + http.Redirect(w, r, target, status) + + redirectCountMetric.With(prometheus.Labels{ + "redirect_code": fmt.Sprintf("%d", status), + "redirect_type": downcaseRedirectHandlerType, + }).Inc() +} diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index d7c93fc5..e9782347 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -15,6 +15,7 @@ import ( ) var _ = Describe("Functioning as a reverse proxy", func() { + var recorder *ghttp.Server Describe("connecting to the backend", func() { It("should return a 502 if the connection to the backend is refused", func() { @@ -74,10 +75,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("response header timeout", func() { - var ( - tarpit1 *httptest.Server - tarpit2 *httptest.Server - ) + var tarpit1, tarpit2 *httptest.Server BeforeEach(func() { err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s"}) @@ -125,14 +123,8 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("header handling", func() { - var ( - recorder *ghttp.Server - recorderURL *url.URL - ) - BeforeEach(func() { recorder = startRecordingBackend() - recorderURL, _ = url.Parse(recorder.URL()) addBackend("backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) @@ -161,6 +153,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Expect(resp.StatusCode).To(Equal(200)) + recorderURL, err := url.Parse(recorder.URL()) + Expect(err).NotTo(HaveOccurred()) + Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] Expect(beReq.Host).To(Equal(recorderURL.Host)) @@ -252,10 +247,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("request verb, path, query and body handling", func() { - var ( - recorder *ghttp.Server - ) - BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) @@ -312,10 +303,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling a backend with a non '/' path", func() { - var ( - recorder *ghttp.Server - ) - BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()+"/something") @@ -347,10 +334,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling HTTP/1.0 requests", func() { - var ( - recorder *ghttp.Server - ) - BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) @@ -384,8 +367,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling requests to a HTTPS backend", func() { - var recorder *ghttp.Server - BeforeEach(func() { err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"}) Expect(err).NotTo(HaveOccurred()) diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 91474c21..9bd20637 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -5,6 +5,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/ghttp" ) var _ = Describe("Redirection", func() { @@ -227,4 +228,48 @@ var _ = Describe("Redirection", func() { Expect(resp.Header.Get("Location")).To(Equal("https://param.servicegov.uk?_ga=12345&included-param=true")) }) }) + + Describe("path case normalisation rule", func() { + var recorder *ghttp.Server + + BeforeEach(func() { + recorder = startRecordingBackend() + addBackend("be", recorder.URL()) + addRoute("/guidance/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) + addRoute("/GUIDANCE/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) + reloadRoutes(apiPort) + }) + + AfterEach(func() { + recorder.Close() + }) + + It("should permanently redirect an ALL CAPS path to lowercase", func() { + resp := routerRequest(routerPort, "/GUIDANCE/KEEPING-A-PET-PIG-OR-MICROPIG") + Expect(resp.StatusCode).To(Equal(301)) + Expect(resp.Header.Get("Location")).To(Equal("/guidance/keeping-a-pet-pig-or-micropig")) + }) + + It("should preserve case in the query string", func() { + resp := routerRequest(routerPort, "/GUIDANCE/KEEPING-A-PET-PIG-OR-MICROPIG?Pig=Kunekune") + Expect(resp.StatusCode).To(Equal(301)) + Expect(resp.Header.Get("Location")).To(Equal("/guidance/keeping-a-pet-pig-or-micropig?Pig=Kunekune")) + }) + + It("should forward an all-lowercase path unchanged", func() { + resp := routerRequest(routerPort, "/guidance/keeping-a-pet-pig-or-micropig") + Expect(resp.StatusCode).To(Equal(200)) + Expect(recorder.ReceivedRequests()).To(HaveLen(1)) + beReq := recorder.ReceivedRequests()[0] + Expect(beReq.URL.RequestURI()).To(Equal("/guidance/keeping-a-pet-pig-or-micropig")) + }) + + It("should forward a mixed-case path unchanged", func() { + resp := routerRequest(routerPort, "/GUIDANCE/keeping-a-pet-pig-or-micropig") + Expect(resp.StatusCode).To(Equal(200)) + Expect(recorder.ReceivedRequests()).To(HaveLen(1)) + beReq := recorder.ReceivedRequests()[0] + Expect(beReq.URL.RequestURI()).To(Equal("/GUIDANCE/keeping-a-pet-pig-or-micropig")) + }) + }) }) diff --git a/triemux/mux.go b/triemux/mux.go index a669c9b0..57e09d18 100644 --- a/triemux/mux.go +++ b/triemux/mux.go @@ -5,9 +5,11 @@ package triemux import ( "net/http" + "regexp" "strings" "sync" + "github.com/alphagov/router/handlers" "github.com/alphagov/router/logger" "github.com/alphagov/router/trie" ) @@ -17,20 +19,22 @@ type Mux struct { exactTrie *trie.Trie[http.Handler] prefixTrie *trie.Trie[http.Handler] count int + downcaser http.Handler } // NewMux makes a new empty Mux. func NewMux() *Mux { return &Mux{ - exactTrie: trie.NewTrie[http.Handler](), + exactTrie: trie.NewTrie[http.Handler](), prefixTrie: trie.NewTrie[http.Handler](), + downcaser: handlers.NewDowncaseRedirectHandler(), } } -// ServeHTTP dispatches the request to a backend with a registered route -// matching the request path, or 404s. -// -// If the routing table is empty, return a 503. +// ServeHTTP forwards the request to a backend with a registered route matching +// the request path. Serves 404 when there is no backend. Serves 301 redirect +// to lowercase path when the URL path is entirely uppercase. Serves 503 when +// no routes are loaded. func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if mux.count == 0 { w.WriteHeader(http.StatusServiceUnavailable) @@ -42,6 +46,11 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if shouldRedirToLowercasePath(r.URL.Path) { + mux.downcaser.ServeHTTP(w, r) + return + } + handler, ok := mux.lookup(r.URL.Path) if !ok { http.NotFound(w, r) @@ -50,6 +59,17 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler.ServeHTTP(w, r) } +// shouldRedirToLowercasePath takes a URL path string (such as "/government/guidance") +// and returns: +// - true, if path is in all caps; for example: +// "/GOVERNMENT/GUIDANCE" -> true (should redirect to "/government/guidance") +// - false, otherwise; for example: +// "/GoVeRnMeNt/gUiDaNcE" -> false (should forward "/GoVeRnMeNt/gUiDaNcE" as-is) +func shouldRedirToLowercasePath(path string) (match bool) { + match, _ = regexp.MatchString(`^\/[A-Z]+[A-Z\W\d]+$`, path) + return +} + // lookup finds a URL path in the Mux and returns the corresponding handler. func (mux *Mux) lookup(path string) (handler http.Handler, ok bool) { mux.mu.RLock() diff --git a/triemux/mux_test.go b/triemux/mux_test.go index 209a1265..8da612b9 100644 --- a/triemux/mux_test.go +++ b/triemux/mux_test.go @@ -10,47 +10,57 @@ import ( promtest "github.com/prometheus/client_golang/prometheus/testutil" ) -type SplitExample struct { - in string - out []string -} - -var splitExamples = []SplitExample{ - {"", []string{}}, - {"/", []string{}}, - {"foo", []string{"foo"}}, - {"/foo", []string{"foo"}}, - {"/füßball", []string{"füßball"}}, - {"/foo/bar", []string{"foo", "bar"}}, - {"///foo/bar", []string{"foo", "bar"}}, - {"foo/bar", []string{"foo", "bar"}}, - {"/foo/bar/", []string{"foo", "bar"}}, - {"/foo//bar/", []string{"foo", "bar"}}, - {"/foo/////bar/", []string{"foo", "bar"}}, -} - func TestSplitPath(t *testing.T) { - for _, ex := range splitExamples { - testSplitPath(t, ex) + tests := []struct { + in string + out []string + }{ + {"", []string{}}, + {"/", []string{}}, + {"foo", []string{"foo"}}, + {"/foo", []string{"foo"}}, + {"/füßball", []string{"füßball"}}, + {"/foo/bar", []string{"foo", "bar"}}, + {"///foo/bar", []string{"foo", "bar"}}, + {"foo/bar", []string{"foo", "bar"}}, + {"/foo/bar/", []string{"foo", "bar"}}, + {"/foo//bar/", []string{"foo", "bar"}}, + {"/foo/////bar/", []string{"foo", "bar"}}, + } + + for _, ex := range tests { + out := splitPath(ex.in) + if len(out) != len(ex.out) { + t.Errorf("splitPath(%v) was not %v", ex.in, ex.out) + } + for i := range ex.out { + if out[i] != ex.out[i] { + t.Errorf("splitPath(%v) differed from %v at component %d "+ + "(expected %v, got %v)", out, ex.out, i, ex.out[i], out[i]) + } + } } } -func testSplitPath(t *testing.T, ex SplitExample) { - out := splitPath(ex.in) - if len(out) != len(ex.out) { - t.Errorf("splitPath(%v) was not %v", ex.in, ex.out) +func TestShouldRedirToLowercasePath(t *testing.T) { + tests := []struct { + in string + out bool + }{ + {"/GOVERNMENT/GUIDANCE", true}, + {"/GoVeRnMeNt/gUiDaNcE", false}, + {"/government/guidance", false}, } - for i := range ex.out { - if out[i] != ex.out[i] { - t.Errorf("splitPath(%v) differed from %v at component %d "+ - "(expected %v, got %v)", out, ex.out, i, ex.out[i], out[i]) + + for _, ex := range tests { + out := shouldRedirToLowercasePath(ex.in) + if out != ex.out { + t.Errorf("shouldRedirToLowercasePath(%v): expected %v, got %v", ex.in, ex.out, out) } } } -type DummyHandler struct { - id string -} +type DummyHandler struct{ id string } func (dh *DummyHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {}