From c08ad4c8e56d059c6215356a6d9cb8d7cea39a0d Mon Sep 17 00:00:00 2001 From: Chris Roche Date: Mon, 2 Oct 2017 09:25:32 -0700 Subject: [PATCH] HTTPHandler: Expose http.ResponseWriter optional interfaces (#18) * checkpoint * Expose all optional http.ResponseWriter interfaces * http.Pusher only exists in go1.8+ * Add Go 1.9 to CI --- .travis.yml | 5 +- stat_handler.go | 88 +++++++++++++++------- stat_handler_test.go | 90 ++++++++++++++++++++++ stat_handler_wrapper.go | 113 ++++++++++++++++++++++++++++ stat_handler_wrapper_1.7.go | 60 +++++++++++++++ stat_handler_wrapper_1.7_test.go | 79 +++++++++++++++++++ stat_handler_wrapper_test.go | 125 +++++++++++++++++++++++++++++++ 7 files changed, 530 insertions(+), 30 deletions(-) create mode 100644 stat_handler_test.go create mode 100644 stat_handler_wrapper.go create mode 100644 stat_handler_wrapper_1.7.go create mode 100644 stat_handler_wrapper_1.7_test.go create mode 100644 stat_handler_wrapper_test.go diff --git a/.travis.yml b/.travis.yml index 3d79df7..1ad9abb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,9 @@ sudo: required language: go go: - - 1.7.4 - - 1.8.2 + - 1.7 + - 1.8 + - 1.9 before_install: mkdir -p $GOPATH/bin install: make install diff --git a/stat_handler.go b/stat_handler.go index 3dd7bd0..0e582f5 100644 --- a/stat_handler.go +++ b/stat_handler.go @@ -3,51 +3,83 @@ package stats import ( "net/http" "strconv" + "sync" ) -// NewStatHandler returns an http handler for stats. -func NewStatHandler(scope Scope, handler http.Handler) http.Handler { - ret := statHandler{} - ret.scope = scope - ret.delegate = handler - ret.timer = ret.scope.NewTimer("rq_time_us") - return &ret -} +const requestTimer = "rq_time_us" -type statHandler struct { +type httpHandler struct { prefix string scope Scope delegate http.Handler - timer Timer + + timer Timer + + codes map[int]Counter + codesMtx sync.RWMutex } -type statResponseWriter struct { - handler *statHandler - delegate http.ResponseWriter - span Timespan +// NewStatHandler returns an http handler for stats. +func NewStatHandler(scope Scope, handler http.Handler) http.Handler { + return &httpHandler{ + scope: scope, + delegate: handler, + timer: scope.NewTimer(requestTimer), + codes: map[int]Counter{}, + } } -func (h *statHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - rw := &statResponseWriter{h, w, h.timer.AllocateSpan()} - h.delegate.ServeHTTP(rw, r) - rw.span.Complete() +func (h *httpHandler) counter(code int) Counter { + h.codesMtx.RLock() + c := h.codes[code] + h.codesMtx.RUnlock() + + if c != nil { + return c + } + + h.codesMtx.Lock() + if c = h.codes[code]; c == nil { + c = h.scope.NewCounter(strconv.Itoa(code)) + h.codes[code] = c + } + h.codesMtx.Unlock() + + return c } -func (h *statResponseWriter) Header() http.Header { - return h.delegate.Header() +func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + span := h.timer.AllocateSpan() + h.delegate.ServeHTTP(h.wrapResponse(w), r) + span.Complete() } -func (h *statResponseWriter) Write(b []byte) (int, error) { - return h.delegate.Write(b) +type responseWriter struct { + http.ResponseWriter + + headerWritten bool + span Timespan + handler *httpHandler } -func (h *statResponseWriter) WriteHeader(code int) { - h.handler.scope.NewCounter(strconv.Itoa(code)).Inc() - h.delegate.WriteHeader(code) +func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.headerWritten { + rw.WriteHeader(http.StatusOK) + } + return rw.ResponseWriter.Write(b) } -func (h *statResponseWriter) Flush() { - if flusher, ok := h.delegate.(http.Flusher); ok { - flusher.Flush() +func (rw *responseWriter) WriteHeader(code int) { + if rw.headerWritten { + return } + + rw.headerWritten = true + rw.handler.counter(code).Inc() + rw.ResponseWriter.WriteHeader(code) } + +var ( + _ http.Handler = (*httpHandler)(nil) + _ http.ResponseWriter = (*responseWriter)(nil) +) diff --git a/stat_handler_test.go b/stat_handler_test.go new file mode 100644 index 0000000..0df0ff3 --- /dev/null +++ b/stat_handler_test.go @@ -0,0 +1,90 @@ +package stats + +import ( + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "testing" +) + +func TestHttpHandler_ServeHTTP(t *testing.T) { + t.Parallel() + + sink := NewMockSink() + store := NewStore(sink, false) + + h := NewStatHandler( + store, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if code, err := strconv.Atoi(r.Header.Get("code")); err == nil { + w.WriteHeader(code) + } + + io.Copy(w, r.Body) + r.Body.Close() + })).(*httpHandler) + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + r, _ := http.NewRequest(http.MethodGet, "/", strings.NewReader("foo")) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + store.Flush() + + if w.Body.String() != "foo" { + t.Errorf("wanted %q body, got %q", "foo", w.Body.String()) + } + + if w.Code != http.StatusOK { + t.Errorf("wanted 200, got %d", w.Code) + } + + wg.Done() + }() + + go func() { + r := httptest.NewRequest(http.MethodGet, "/", strings.NewReader("bar")) + r.Header.Set("code", strconv.Itoa(http.StatusNotFound)) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + store.Flush() + + if w.Body.String() != "bar" { + t.Errorf("wanted %q body, got %q", "bar", w.Body.String()) + } + + if w.Code != http.StatusNotFound { + t.Errorf("wanted 404, got %d", w.Code) + } + + wg.Done() + }() + + wg.Wait() + + timer, ok := sink.Timers[requestTimer] + if !ok { + t.Errorf("wanted a %q timer, none found", requestTimer) + } else if timer != 2 { + t.Error("wanted 2, got", timer) + } + + c, ok := sink.Counters["200"] + if !ok { + t.Error("wanted a '200' counter, none found") + } else if c != 1 { + t.Error("wanted 1, got", c) + } + + c, ok = sink.Counters["404"] + if !ok { + t.Error("wanted a '404' counter, none found") + } else if c != 1 { + t.Error("wanted 1, got", c) + } +} diff --git a/stat_handler_wrapper.go b/stat_handler_wrapper.go new file mode 100644 index 0000000..acd19e8 --- /dev/null +++ b/stat_handler_wrapper.go @@ -0,0 +1,113 @@ +// +build go1.8 + +package stats + +import "net/http" + +func (h *httpHandler) wrapResponse(w http.ResponseWriter) http.ResponseWriter { + rw := &responseWriter{ + ResponseWriter: w, + handler: h, + } + + flusher, canFlush := w.(http.Flusher) + hijacker, canHijack := w.(http.Hijacker) + pusher, canPush := w.(http.Pusher) + closeNotifier, canNotify := w.(http.CloseNotifier) + + if canFlush && canHijack && canPush && canNotify { + return struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.Pusher + http.CloseNotifier + }{rw, flusher, hijacker, pusher, closeNotifier} + } else if canFlush && canHijack && canPush { + return struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.Pusher + }{rw, flusher, hijacker, pusher} + } else if canFlush && canHijack && canNotify { + return struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.CloseNotifier + }{rw, flusher, hijacker, closeNotifier} + } else if canFlush && canPush && canNotify { + return struct { + http.ResponseWriter + http.Flusher + http.Pusher + http.CloseNotifier + }{rw, flusher, pusher, closeNotifier} + } else if canHijack && canPush && canNotify { + return struct { + http.ResponseWriter + http.Hijacker + http.Pusher + http.CloseNotifier + }{rw, hijacker, pusher, closeNotifier} + } else if canFlush && canHijack { + return struct { + http.ResponseWriter + http.Flusher + http.Hijacker + }{rw, flusher, hijacker} + } else if canFlush && canPush { + return struct { + http.ResponseWriter + http.Flusher + http.Pusher + }{rw, flusher, pusher} + } else if canFlush && canNotify { + return struct { + http.ResponseWriter + http.Flusher + http.CloseNotifier + }{rw, flusher, closeNotifier} + } else if canHijack && canPush { + return struct { + http.ResponseWriter + http.Hijacker + http.Pusher + }{rw, hijacker, pusher} + } else if canHijack && canNotify { + return struct { + http.ResponseWriter + http.Hijacker + http.CloseNotifier + }{rw, hijacker, closeNotifier} + } else if canPush && canNotify { + return struct { + http.ResponseWriter + http.Pusher + http.CloseNotifier + }{rw, pusher, closeNotifier} + } else if canFlush { + return struct { + http.ResponseWriter + http.Flusher + }{rw, flusher} + } else if canHijack { + return struct { + http.ResponseWriter + http.Hijacker + }{rw, hijacker} + } else if canPush { + return struct { + http.ResponseWriter + http.Pusher + }{rw, pusher} + } else if canNotify { + return struct { + http.ResponseWriter + http.CloseNotifier + }{rw, closeNotifier} + } + + return rw +} diff --git a/stat_handler_wrapper_1.7.go b/stat_handler_wrapper_1.7.go new file mode 100644 index 0000000..24c1969 --- /dev/null +++ b/stat_handler_wrapper_1.7.go @@ -0,0 +1,60 @@ +// +build !go1.8 + +package stats + +import "net/http" + +func (h *httpHandler) wrapResponse(w http.ResponseWriter) http.ResponseWriter { + rw := &responseWriter{ + ResponseWriter: w, + handler: h, + } + + flusher, canFlush := w.(http.Flusher) + hijacker, canHijack := w.(http.Hijacker) + closeNotifier, canNotify := w.(http.CloseNotifier) + + if canFlush && canHijack && canNotify { + return struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.CloseNotifier + }{rw, flusher, hijacker, closeNotifier} + } else if canFlush && canHijack { + return struct { + http.ResponseWriter + http.Flusher + http.Hijacker + }{rw, flusher, hijacker} + } else if canFlush && canNotify { + return struct { + http.ResponseWriter + http.Flusher + http.CloseNotifier + }{rw, flusher, closeNotifier} + } else if canHijack && canNotify { + return struct { + http.ResponseWriter + http.Hijacker + http.CloseNotifier + }{rw, hijacker, closeNotifier} + } else if canFlush { + return struct { + http.ResponseWriter + http.Flusher + }{rw, flusher} + } else if canHijack { + return struct { + http.ResponseWriter + http.Hijacker + }{rw, hijacker} + } else if canNotify { + return struct { + http.ResponseWriter + http.CloseNotifier + }{rw, closeNotifier} + } + + return rw +} diff --git a/stat_handler_wrapper_1.7_test.go b/stat_handler_wrapper_1.7_test.go new file mode 100644 index 0000000..9b5bdf4 --- /dev/null +++ b/stat_handler_wrapper_1.7_test.go @@ -0,0 +1,79 @@ +// +build !go1.8 + +package stats + +import ( + "fmt" + "net/http" + "testing" +) + +func TestHTTPHandler_WrapResponse(t *testing.T) { + t.Parallel() + + tests := []http.ResponseWriter{ + struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Flusher + http.Hijacker + }{}, + struct { + http.ResponseWriter + http.Flusher + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Flusher + }{}, + struct { + http.ResponseWriter + http.Hijacker + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Hijacker + }{}, + struct { + http.ResponseWriter + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + }{}, + } + + h := NewStatHandler( + NewStore(NewNullSink(), false), + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).(*httpHandler) + + for i, test := range tests { + tc := test + t.Run(fmt.Sprint("test:", i), func(t *testing.T) { + t.Parallel() + + _, canFlush := tc.(http.Flusher) + _, canHijack := tc.(http.Hijacker) + _, canNotify := tc.(http.CloseNotifier) + + rw := h.wrapResponse(tc) + + if _, ok := rw.(http.Flusher); ok != canFlush { + t.Errorf("Flusher: wanted %t", canFlush) + } + if _, ok := rw.(http.Hijacker); ok != canHijack { + t.Errorf("Hijacker: wanted %t", canHijack) + } + if _, ok := rw.(http.CloseNotifier); ok != canNotify { + t.Errorf("CloseNotifier: wanted %t", canNotify) + } + }) + } +} diff --git a/stat_handler_wrapper_test.go b/stat_handler_wrapper_test.go new file mode 100644 index 0000000..c8b1d8d --- /dev/null +++ b/stat_handler_wrapper_test.go @@ -0,0 +1,125 @@ +// +build go1.8 + +package stats + +import ( + "fmt" + "net/http" + "testing" +) + +func TestHTTPHandler_WrapResponse(t *testing.T) { + t.Parallel() + + tests := []http.ResponseWriter{ + struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.Pusher + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.Pusher + }{}, + struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Flusher + http.Hijacker + }{}, + struct { + http.ResponseWriter + http.Flusher + http.Pusher + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Flusher + http.Pusher + }{}, + struct { + http.ResponseWriter + http.Flusher + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Flusher + }{}, + struct { + http.ResponseWriter + http.Hijacker + http.Pusher + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Hijacker + http.Pusher + }{}, + struct { + http.ResponseWriter + http.Hijacker + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Hijacker + }{}, + struct { + http.ResponseWriter + http.Pusher + http.CloseNotifier + }{}, + struct { + http.ResponseWriter + http.Pusher + }{}, + struct { + http.ResponseWriter + http.CloseNotifier + }{}, + struct{ http.ResponseWriter }{}, + } + + h := NewStatHandler( + NewStore(NewNullSink(), false), + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).(*httpHandler) + + for i, test := range tests { + tc := test + t.Run(fmt.Sprint("test:", i), func(t *testing.T) { + t.Parallel() + + _, canFlush := tc.(http.Flusher) + _, canHijack := tc.(http.Hijacker) + _, canPush := tc.(http.Pusher) + _, canNotify := tc.(http.CloseNotifier) + + rw := h.wrapResponse(tc) + + if _, ok := rw.(http.Flusher); ok != canFlush { + t.Errorf("Flusher: wanted %t", canFlush) + } + if _, ok := rw.(http.Hijacker); ok != canHijack { + t.Errorf("Hijacker: wanted %t", canHijack) + } + if _, ok := rw.(http.Pusher); ok != canPush { + t.Errorf("Pusher: wanted %t", canPush) + } + if _, ok := rw.(http.CloseNotifier); ok != canNotify { + t.Errorf("CloseNotifier: wanted %t", canNotify) + } + }) + } +}