From d48c5a120298c326cd160efe00ea6321c1c009c8 Mon Sep 17 00:00:00 2001 From: Sam Uong Date: Thu, 15 Oct 2020 17:17:09 -0400 Subject: [PATCH] Remove bare string used as context key Also, use sync/atomic's AddUint64() to generate IDs, which is simpler and avoids tying up a goroutine. Fixes #17 --- contextid.go | 24 ++++++++++++------------ contextid_test.go | 4 ++-- proxy.go | 6 +++--- proxyfinder.go | 2 +- proxyfinder_test.go | 3 ++- requestlogger.go | 8 +++++++- requestlogger_test.go | 2 +- 7 files changed, 28 insertions(+), 21 deletions(-) diff --git a/contextid.go b/contextid.go index 534f149..2057161 100644 --- a/contextid.go +++ b/contextid.go @@ -17,22 +17,22 @@ package main import ( "context" "net/http" + "sync/atomic" ) -// AddContextID wraps a http.Handler to add a strictly increasing -// uint to the context of the http.Request with the key "id" (string) -// as it passes through the request to the next handler. +type contextKey string + +const contextKeyID = contextKey("id") + +// AddContextID wraps a http.Handler to add a strictly increasing uint to the +// context of the http.Request with the key "id" as it passes through the +// request to the next handler. func AddContextID(next http.Handler) http.Handler { - // TODO(#17): Use sync/atomic AddUint64 instead of channel/goroutine - ids := make(chan uint) - go func() { - for id := uint(0); ; id++ { - ids <- id - } - }() + var id uint64 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // TODO(#17): Use package scoped type instead of string for key - ctx := context.WithValue(req.Context(), "id", <-ids) + ctx := context.WithValue( + req.Context(), contextKeyID, atomic.AddUint64(&id, 1), + ) next.ServeHTTP(w, req.WithContext(ctx)) }) } diff --git a/contextid_test.go b/contextid_test.go index 6d4f2c1..eb77e6a 100644 --- a/contextid_test.go +++ b/contextid_test.go @@ -37,13 +37,13 @@ func getIDFromRequest(t *testing.T, server *httptest.Server) uint { func TestContextID(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id, ok := r.Context().Value("id").(uint) + id, ok := r.Context().Value(contextKeyID).(uint64) assert.True(t, ok, "Unexpected type for context id value") _, err := w.Write([]byte(strconv.FormatUint(uint64(id), 10))) require.NoError(t, err) }) server := httptest.NewServer(AddContextID(handler)) defer server.Close() - assert.Equal(t, uint(0), getIDFromRequest(t, server)) assert.Equal(t, uint(1), getIDFromRequest(t, server)) + assert.Equal(t, uint(2), getIDFromRequest(t, server)) } diff --git a/proxy.go b/proxy.go index 54e63ff..7a95843 100644 --- a/proxy.go +++ b/proxy.go @@ -69,7 +69,7 @@ func (ph ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) { // Establish a connection to the server, or an upstream proxy. u, err := ph.transport.Proxy(req) - id := req.Context().Value("id") + id := req.Context().Value(contextKeyID) if err != nil { log.Printf("[%d] Error finding proxy for %v: %v", id, req.Host, err) w.WriteHeader(http.StatusInternalServerError) @@ -127,7 +127,7 @@ func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) { func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net.Conn, error) { // can't hijack the connection to server, so can't just replay request via a Transport // need to dial and manually write connect header and read response - id := req.Context().Value("id") + id := req.Context().Value(contextKeyID) conn, err := net.Dial("tcp", proxy) if err != nil { log.Printf("[%d] Error dialling %s: %v", id, proxy, err) @@ -171,7 +171,7 @@ func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net. func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, auth *authenticator) { // Make a copy of the request body, in case we have to replay it (for authentication) var buf bytes.Buffer - id := req.Context().Value("id") + id := req.Context().Value(contextKeyID) if n, err := io.Copy(&buf, req.Body); err != nil { log.Printf("[%d] Error copying request body (got %d/%d): %v", id, n, req.ContentLength, err) diff --git a/proxyfinder.go b/proxyfinder.go index b13d6b9..1d5d15a 100644 --- a/proxyfinder.go +++ b/proxyfinder.go @@ -76,7 +76,7 @@ func (pf *ProxyFinder) checkForUpdates() { } func (pf *ProxyFinder) findProxyForRequest(req *http.Request) (*url.URL, error) { - id := req.Context().Value("id") + id := req.Context().Value(contextKeyID) if pf.fetcher == nil { log.Printf(`[%d] %s %s via "DIRECT"`, id, req.Method, req.URL) return nil, nil diff --git a/proxyfinder_test.go b/proxyfinder_test.go index de15786..5f1efa2 100644 --- a/proxyfinder_test.go +++ b/proxyfinder_test.go @@ -48,7 +48,8 @@ func TestFindProxyForRequest(t *testing.T) { pw := NewPACWrapper(PACData{Port: 1}) pf := NewProxyFinder(server.URL, pw) req := httptest.NewRequest(http.MethodGet, "https://www.test", nil) - req = req.WithContext(context.WithValue(req.Context(), "id", i)) + ctx := context.WithValue(req.Context(), contextKeyID, i) + req = req.WithContext(ctx) proxy, err := pf.findProxyForRequest(req) if test.expectError { assert.NotNil(t, err) diff --git a/requestlogger.go b/requestlogger.go index c18840a..a69b559 100644 --- a/requestlogger.go +++ b/requestlogger.go @@ -33,6 +33,12 @@ func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(sw, req) - log.Printf("[%v] %d %s %s", req.Context().Value("id"), sw.status, req.Method, req.URL) + log.Printf( + "[%v] %d %s %s", + req.Context().Value(contextKeyID), + sw.status, + req.Method, + req.URL, + ) }) } diff --git a/requestlogger_test.go b/requestlogger_test.go index 5f3cba2..c6b554c 100644 --- a/requestlogger_test.go +++ b/requestlogger_test.go @@ -34,7 +34,7 @@ func TestRequestLogger(t *testing.T) { }{ "No Status": {0, nil, "[] 200 GET /"}, "Given Status": {http.StatusNotFound, nil, "[] 404 GET /"}, - "Context": {http.StatusOK, AddContextID, "[0] 200 GET /"}, + "Context": {http.StatusOK, AddContextID, "[1] 200 GET /"}, } for name, tt := range tests { t.Run(name, func(t *testing.T) {