Skip to content

Commit

Permalink
Remove bare string used as context key
Browse files Browse the repository at this point in the history
Also, use sync/atomic's AddUint64() to generate IDs, which is simpler
and avoids tying up a goroutine.

Fixes #17
  • Loading branch information
samuong committed Oct 16, 2020
1 parent 25fd9e1 commit d48c5a1
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 21 deletions.
24 changes: 12 additions & 12 deletions contextid.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@ package main
import (
"context"
"net/http"
"sync/atomic"
)

// AddContextID wraps a http.Handler to add a strictly increasing
// uint to the context of the http.Request with the key "id" (string)
// as it passes through the request to the next handler.
type contextKey string

const contextKeyID = contextKey("id")

// AddContextID wraps a http.Handler to add a strictly increasing uint to the
// context of the http.Request with the key "id" as it passes through the
// request to the next handler.
func AddContextID(next http.Handler) http.Handler {
// TODO(#17): Use sync/atomic AddUint64 instead of channel/goroutine
ids := make(chan uint)
go func() {
for id := uint(0); ; id++ {
ids <- id
}
}()
var id uint64
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// TODO(#17): Use package scoped type instead of string for key
ctx := context.WithValue(req.Context(), "id", <-ids)
ctx := context.WithValue(
req.Context(), contextKeyID, atomic.AddUint64(&id, 1),
)
next.ServeHTTP(w, req.WithContext(ctx))
})
}
4 changes: 2 additions & 2 deletions contextid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ func getIDFromRequest(t *testing.T, server *httptest.Server) uint {

func TestContextID(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id, ok := r.Context().Value("id").(uint)
id, ok := r.Context().Value(contextKeyID).(uint64)
assert.True(t, ok, "Unexpected type for context id value")
_, err := w.Write([]byte(strconv.FormatUint(uint64(id), 10)))
require.NoError(t, err)
})
server := httptest.NewServer(AddContextID(handler))
defer server.Close()
assert.Equal(t, uint(0), getIDFromRequest(t, server))
assert.Equal(t, uint(1), getIDFromRequest(t, server))
assert.Equal(t, uint(2), getIDFromRequest(t, server))
}
6 changes: 3 additions & 3 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (ph ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) {
// Establish a connection to the server, or an upstream proxy.
u, err := ph.transport.Proxy(req)
id := req.Context().Value("id")
id := req.Context().Value(contextKeyID)
if err != nil {
log.Printf("[%d] Error finding proxy for %v: %v", id, req.Host, err)
w.WriteHeader(http.StatusInternalServerError)
Expand Down Expand Up @@ -127,7 +127,7 @@ func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) {
func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net.Conn, error) {
// can't hijack the connection to server, so can't just replay request via a Transport
// need to dial and manually write connect header and read response
id := req.Context().Value("id")
id := req.Context().Value(contextKeyID)
conn, err := net.Dial("tcp", proxy)
if err != nil {
log.Printf("[%d] Error dialling %s: %v", id, proxy, err)
Expand Down Expand Up @@ -171,7 +171,7 @@ func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net.
func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, auth *authenticator) {
// Make a copy of the request body, in case we have to replay it (for authentication)
var buf bytes.Buffer
id := req.Context().Value("id")
id := req.Context().Value(contextKeyID)
if n, err := io.Copy(&buf, req.Body); err != nil {
log.Printf("[%d] Error copying request body (got %d/%d): %v",
id, n, req.ContentLength, err)
Expand Down
2 changes: 1 addition & 1 deletion proxyfinder.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (pf *ProxyFinder) checkForUpdates() {
}

func (pf *ProxyFinder) findProxyForRequest(req *http.Request) (*url.URL, error) {
id := req.Context().Value("id")
id := req.Context().Value(contextKeyID)
if pf.fetcher == nil {
log.Printf(`[%d] %s %s via "DIRECT"`, id, req.Method, req.URL)
return nil, nil
Expand Down
3 changes: 2 additions & 1 deletion proxyfinder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ func TestFindProxyForRequest(t *testing.T) {
pw := NewPACWrapper(PACData{Port: 1})
pf := NewProxyFinder(server.URL, pw)
req := httptest.NewRequest(http.MethodGet, "https://www.test", nil)
req = req.WithContext(context.WithValue(req.Context(), "id", i))
ctx := context.WithValue(req.Context(), contextKeyID, i)
req = req.WithContext(ctx)
proxy, err := pf.findProxyForRequest(req)
if test.expectError {
assert.NotNil(t, err)
Expand Down
8 changes: 7 additions & 1 deletion requestlogger.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(sw, req)
log.Printf("[%v] %d %s %s", req.Context().Value("id"), sw.status, req.Method, req.URL)
log.Printf(
"[%v] %d %s %s",
req.Context().Value(contextKeyID),
sw.status,
req.Method,
req.URL,
)
})
}
2 changes: 1 addition & 1 deletion requestlogger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestRequestLogger(t *testing.T) {
}{
"No Status": {0, nil, "[<nil>] 200 GET /"},
"Given Status": {http.StatusNotFound, nil, "[<nil>] 404 GET /"},
"Context": {http.StatusOK, AddContextID, "[0] 200 GET /"},
"Context": {http.StatusOK, AddContextID, "[1] 200 GET /"},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
Expand Down

0 comments on commit d48c5a1

Please sign in to comment.