Skip to content

Commit

Permalink
Merge pull request #424 from alphagov/sengi/lowercase
Browse files Browse the repository at this point in the history
Redirect all-caps article slugs to lowercase.
  • Loading branch information
sengi authored Jan 22, 2024
2 parents 447c05c + a37389c commit 2e53c1a
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 77 deletions.
56 changes: 40 additions & 16 deletions handlers/redirect_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const (

redirectHandlerType = "redirect-handler"
pathPreservingRedirectHandlerType = "path-preserving-redirect-handler"
downcaseRedirectHandlerType = "downcase-redirect-handler"
)

func NewRedirectHandler(source, target string, preserve bool, temporary bool) http.Handler {
Expand All @@ -30,16 +31,16 @@ func NewRedirectHandler(source, target string, preserve bool, temporary bool) ht
return &redirectHandler{target, statusMoved}
}

func addCacheHeaders(writer http.ResponseWriter) {
writer.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123))
writer.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second))
func addCacheHeaders(w http.ResponseWriter) {
w.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123))
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second))
}

func addGAQueryParam(target string, request *http.Request) string {
if ga := request.URL.Query().Get("_ga"); ga != "" {
func addGAQueryParam(target string, r *http.Request) string {
if ga := r.URL.Query().Get("_ga"); ga != "" {
u, err := url.Parse(target)
if err != nil {
defer logger.NotifySentry(logger.ReportableError{Error: err, Request: request})
defer logger.NotifySentry(logger.ReportableError{Error: err, Request: r})
return target
}
values := u.Query()
Expand All @@ -55,11 +56,11 @@ type redirectHandler struct {
code int
}

func (handler *redirectHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
addCacheHeaders(writer)
func (handler *redirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
addCacheHeaders(w)

target := addGAQueryParam(handler.url, request)
http.Redirect(writer, request, target, handler.code)
target := addGAQueryParam(handler.url, r)
http.Redirect(w, r, target, handler.code)

redirectCountMetric.With(prometheus.Labels{
"redirect_code": fmt.Sprintf("%d", handler.code),
Expand All @@ -73,17 +74,40 @@ type pathPreservingRedirectHandler struct {
code int
}

func (handler *pathPreservingRedirectHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
target := handler.targetPrefix + strings.TrimPrefix(request.URL.Path, handler.sourcePrefix)
if request.URL.RawQuery != "" {
target += "?" + request.URL.RawQuery
func (handler *pathPreservingRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
target := handler.targetPrefix + strings.TrimPrefix(r.URL.Path, handler.sourcePrefix)
if r.URL.RawQuery != "" {
target += "?" + r.URL.RawQuery
}

addCacheHeaders(writer)
http.Redirect(writer, request, target, handler.code)
addCacheHeaders(w)
http.Redirect(w, r, target, handler.code)

redirectCountMetric.With(prometheus.Labels{
"redirect_code": fmt.Sprintf("%d", handler.code),
"redirect_type": pathPreservingRedirectHandlerType,
}).Inc()
}

// downcaseRedirectHandler permanently redirects a request to the same URL
// with its path converted to lowercase. It holds no state.
type downcaseRedirectHandler struct{}

// NewDowncaseRedirectHandler returns an http.Handler that answers every
// request with a 301 redirect to the lowercased request path. The query
// string, if present, is carried over unchanged (its case is preserved).
func NewDowncaseRedirectHandler() http.Handler {
	return &downcaseRedirectHandler{}
}

// ServeHTTP writes the cache headers, issues the 301 redirect to the
// lowercased path and increments the redirect counter for this handler type.
func (handler *downcaseRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const status = http.StatusMovedPermanently

	// r.URL.Path holds the *decoded* path, so it has to be escaped again
	// before it is placed in the Location header. Plain concatenation would
	// turn a request for "/FOO%3FBAR" into a redirect to "/foo?bar", moving
	// part of the path into the query string. url.URL.String escapes the path
	// and appends "?"+RawQuery only when the query is non-empty.
	target := (&url.URL{
		Path:     strings.ToLower(r.URL.Path),
		RawQuery: r.URL.RawQuery,
	}).String()

	addCacheHeaders(w)
	http.Redirect(w, r, target, status)

	redirectCountMetric.With(prometheus.Labels{
		"redirect_code": fmt.Sprintf("%d", status),
		"redirect_type": downcaseRedirectHandlerType,
	}).Inc()
}
29 changes: 5 additions & 24 deletions integration_tests/proxy_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
)

var _ = Describe("Functioning as a reverse proxy", func() {
var recorder *ghttp.Server

Describe("connecting to the backend", func() {
It("should return a 502 if the connection to the backend is refused", func() {
Expand Down Expand Up @@ -74,10 +75,7 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})

Describe("response header timeout", func() {
var (
tarpit1 *httptest.Server
tarpit2 *httptest.Server
)
var tarpit1, tarpit2 *httptest.Server

BeforeEach(func() {
err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s"})
Expand Down Expand Up @@ -125,14 +123,8 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})

Describe("header handling", func() {
var (
recorder *ghttp.Server
recorderURL *url.URL
)

BeforeEach(func() {
recorder = startRecordingBackend()
recorderURL, _ = url.Parse(recorder.URL())
addBackend("backend", recorder.URL())
addRoute("/foo", NewBackendRoute("backend", "prefix"))
reloadRoutes(apiPort)
Expand Down Expand Up @@ -161,6 +153,9 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})
Expect(resp.StatusCode).To(Equal(200))

recorderURL, err := url.Parse(recorder.URL())
Expect(err).NotTo(HaveOccurred())

Expect(recorder.ReceivedRequests()).To(HaveLen(1))
beReq := recorder.ReceivedRequests()[0]
Expect(beReq.Host).To(Equal(recorderURL.Host))
Expand Down Expand Up @@ -252,10 +247,6 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})

Describe("request verb, path, query and body handling", func() {
var (
recorder *ghttp.Server
)

BeforeEach(func() {
recorder = startRecordingBackend()
addBackend("backend", recorder.URL())
Expand Down Expand Up @@ -312,10 +303,6 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})

Describe("handling a backend with a non '/' path", func() {
var (
recorder *ghttp.Server
)

BeforeEach(func() {
recorder = startRecordingBackend()
addBackend("backend", recorder.URL()+"/something")
Expand Down Expand Up @@ -347,10 +334,6 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})

Describe("handling HTTP/1.0 requests", func() {
var (
recorder *ghttp.Server
)

BeforeEach(func() {
recorder = startRecordingBackend()
addBackend("backend", recorder.URL())
Expand Down Expand Up @@ -384,8 +367,6 @@ var _ = Describe("Functioning as a reverse proxy", func() {
})

Describe("handling requests to a HTTPS backend", func() {
var recorder *ghttp.Server

BeforeEach(func() {
err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"})
Expect(err).NotTo(HaveOccurred())
Expand Down
45 changes: 45 additions & 0 deletions integration_tests/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/ghttp"
)

var _ = Describe("Redirection", func() {
Expand Down Expand Up @@ -227,4 +228,48 @@ var _ = Describe("Redirection", func() {
Expect(resp.Header.Get("Location")).To(Equal("https://param.servicegov.uk?_ga=12345&included-param=true"))
})
})

// Exercises how the router treats the case of the request path: a path that
// is entirely upper case is redirected to its lower-case form, while
// lower-case and mixed-case paths are forwarded to the backend as they are.
Describe("path case normalisation rule", func() {
var recorder *ghttp.Server

BeforeEach(func() {
// Register the same recording backend under a lower-case route and a
// mixed-case route so both "forward unchanged" cases below have a
// matching exact route.
recorder = startRecordingBackend()
addBackend("be", recorder.URL())
addRoute("/guidance/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact"))
addRoute("/GUIDANCE/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact"))
reloadRoutes(apiPort)
})

AfterEach(func() {
recorder.Close()
})

// No all-caps route is registered: the 301 is expected from the router
// itself, with a relative Location pointing at the lower-case path.
It("should permanently redirect an ALL CAPS path to lowercase", func() {
resp := routerRequest(routerPort, "/GUIDANCE/KEEPING-A-PET-PIG-OR-MICROPIG")
Expect(resp.StatusCode).To(Equal(301))
Expect(resp.Header.Get("Location")).To(Equal("/guidance/keeping-a-pet-pig-or-micropig"))
})

// Only the path is lower-cased; the query string must keep its case.
It("should preserve case in the query string", func() {
resp := routerRequest(routerPort, "/GUIDANCE/KEEPING-A-PET-PIG-OR-MICROPIG?Pig=Kunekune")
Expect(resp.StatusCode).To(Equal(301))
Expect(resp.Header.Get("Location")).To(Equal("/guidance/keeping-a-pet-pig-or-micropig?Pig=Kunekune"))
})

// A lower-case path must reach the backend exactly once, unmodified.
It("should forward an all-lowercase path unchanged", func() {
resp := routerRequest(routerPort, "/guidance/keeping-a-pet-pig-or-micropig")
Expect(resp.StatusCode).To(Equal(200))
Expect(recorder.ReceivedRequests()).To(HaveLen(1))
beReq := recorder.ReceivedRequests()[0]
Expect(beReq.URL.RequestURI()).To(Equal("/guidance/keeping-a-pet-pig-or-micropig"))
})

// A path that mixes upper and lower case is not redirected; it is
// forwarded with its original case intact.
It("should forward a mixed-case path unchanged", func() {
resp := routerRequest(routerPort, "/GUIDANCE/keeping-a-pet-pig-or-micropig")
Expect(resp.StatusCode).To(Equal(200))
Expect(recorder.ReceivedRequests()).To(HaveLen(1))
beReq := recorder.ReceivedRequests()[0]
Expect(beReq.URL.RequestURI()).To(Equal("/GUIDANCE/keeping-a-pet-pig-or-micropig"))
})
})
})
30 changes: 25 additions & 5 deletions triemux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ package triemux

import (
"net/http"
"regexp"
"strings"
"sync"

"github.com/alphagov/router/handlers"
"github.com/alphagov/router/logger"
"github.com/alphagov/router/trie"
)
Expand All @@ -17,20 +19,22 @@ type Mux struct {
exactTrie *trie.Trie[http.Handler]
prefixTrie *trie.Trie[http.Handler]
count int
downcaser http.Handler
}

// NewMux makes a new empty Mux.
func NewMux() *Mux {
return &Mux{
exactTrie: trie.NewTrie[http.Handler](),
exactTrie: trie.NewTrie[http.Handler](),
prefixTrie: trie.NewTrie[http.Handler](),
downcaser: handlers.NewDowncaseRedirectHandler(),
}
}

// ServeHTTP dispatches the request to a backend with a registered route
// matching the request path, or 404s.
//
// If the routing table is empty, return a 503.
// ServeHTTP forwards the request to a backend with a registered route matching
// the request path. Serves 404 when there is no backend. Serves 301 redirect
// to lowercase path when the URL path is entirely uppercase. Serves 503 when
// no routes are loaded.
func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if mux.count == 0 {
w.WriteHeader(http.StatusServiceUnavailable)
Expand All @@ -42,6 +46,11 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

if shouldRedirToLowercasePath(r.URL.Path) {
mux.downcaser.ServeHTTP(w, r)
return
}

handler, ok := mux.lookup(r.URL.Path)
if !ok {
http.NotFound(w, r)
Expand All @@ -50,6 +59,17 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
}

// allCapsPathRE matches a URL path that should be redirected to lowercase:
// a leading "/", then one or more ASCII uppercase letters, then one or more
// characters that are each an ASCII uppercase letter, a digit, or any
// character outside [0-9A-Za-z_] (for example "/" or "-"). Lowercase ASCII
// letters and underscores therefore prevent a match.
//
// It is compiled once at package initialisation because the check runs on
// every request.
var allCapsPathRE = regexp.MustCompile(`^\/[A-Z]+[A-Z\W\d]+$`)

// shouldRedirToLowercasePath takes a URL path string (such as "/government/guidance")
// and returns:
//   - true, if path is in all caps; for example:
//     "/GOVERNMENT/GUIDANCE" -> true (should redirect to "/government/guidance")
//   - false, otherwise; for example:
//     "/GoVeRnMeNt/gUiDaNcE" -> false (should forward "/GoVeRnMeNt/gUiDaNcE" as-is)
//
// The pattern needs at least two characters after the leading slash, so "",
// "/" and "/A" all return false.
func shouldRedirToLowercasePath(path string) bool {
	return allCapsPathRE.MatchString(path)
}

// lookup finds a URL path in the Mux and returns the corresponding handler.
func (mux *Mux) lookup(path string) (handler http.Handler, ok bool) {
mux.mu.RLock()
Expand Down
74 changes: 42 additions & 32 deletions triemux/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,57 @@ import (
promtest "github.com/prometheus/client_golang/prometheus/testutil"
)

type SplitExample struct {
in string
out []string
}

var splitExamples = []SplitExample{
{"", []string{}},
{"/", []string{}},
{"foo", []string{"foo"}},
{"/foo", []string{"foo"}},
{"/füßball", []string{"füßball"}},
{"/foo/bar", []string{"foo", "bar"}},
{"///foo/bar", []string{"foo", "bar"}},
{"foo/bar", []string{"foo", "bar"}},
{"/foo/bar/", []string{"foo", "bar"}},
{"/foo//bar/", []string{"foo", "bar"}},
{"/foo/////bar/", []string{"foo", "bar"}},
}

func TestSplitPath(t *testing.T) {
for _, ex := range splitExamples {
testSplitPath(t, ex)
tests := []struct {
in string
out []string
}{
{"", []string{}},
{"/", []string{}},
{"foo", []string{"foo"}},
{"/foo", []string{"foo"}},
{"/füßball", []string{"füßball"}},
{"/foo/bar", []string{"foo", "bar"}},
{"///foo/bar", []string{"foo", "bar"}},
{"foo/bar", []string{"foo", "bar"}},
{"/foo/bar/", []string{"foo", "bar"}},
{"/foo//bar/", []string{"foo", "bar"}},
{"/foo/////bar/", []string{"foo", "bar"}},
}

for _, ex := range tests {
out := splitPath(ex.in)
if len(out) != len(ex.out) {
t.Errorf("splitPath(%v) was not %v", ex.in, ex.out)
}
for i := range ex.out {
if out[i] != ex.out[i] {
t.Errorf("splitPath(%v) differed from %v at component %d "+
"(expected %v, got %v)", out, ex.out, i, ex.out[i], out[i])
}
}
}
}

func testSplitPath(t *testing.T, ex SplitExample) {
out := splitPath(ex.in)
if len(out) != len(ex.out) {
t.Errorf("splitPath(%v) was not %v", ex.in, ex.out)
func TestShouldRedirToLowercasePath(t *testing.T) {
tests := []struct {
in string
out bool
}{
{"/GOVERNMENT/GUIDANCE", true},
{"/GoVeRnMeNt/gUiDaNcE", false},
{"/government/guidance", false},
}
for i := range ex.out {
if out[i] != ex.out[i] {
t.Errorf("splitPath(%v) differed from %v at component %d "+
"(expected %v, got %v)", out, ex.out, i, ex.out[i], out[i])

for _, ex := range tests {
out := shouldRedirToLowercasePath(ex.in)
if out != ex.out {
t.Errorf("shouldRedirToLowercasePath(%v): expected %v, got %v", ex.in, ex.out, out)
}
}
}

type DummyHandler struct {
id string
}
type DummyHandler struct{ id string }

func (dh *DummyHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {}

Expand Down

0 comments on commit 2e53c1a

Please sign in to comment.