From a15f0c34add936c91b7a317823d81b11fb73f7c0 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Sun, 11 Dec 2016 16:19:35 +0100 Subject: [PATCH] transport/http: add ServerFinalizer --- transport/http/server.go | 35 ++++++++++++++++++++++++++++++++++- transport/http/server_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/transport/http/server.go b/transport/http/server.go index db528bef6..0f2e64372 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -18,6 +18,7 @@ type Server struct { before []RequestFunc after []ServerResponseFunc errorEncoder ErrorEncoder + finalizer ServerFinalizerFunc logger log.Logger } @@ -69,15 +70,30 @@ func ServerErrorEncoder(ee ErrorEncoder) ServerOption { } // ServerErrorLogger is used to log non-terminal errors. By default, no errors -// are logged. +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ServerErrorEncoder or ServerFinalizer, both of which have access to +// the context. func ServerErrorLogger(logger log.Logger) ServerOption { return func(s *Server) { s.logger = logger } } +// ServerFinalizer is executed at the end of every HTTP request. +// By default, no finalizer is registered. +func ServerFinalizer(f ServerFinalizerFunc) ServerOption { + return func(s *Server) { s.finalizer = f } +} + // ServeHTTP implements http.Handler. func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := s.ctx + if s.finalizer != nil { + iw := &interceptingWriter{w, http.StatusOK} + defer func() { s.finalizer(ctx, iw.code, r) }() + w = iw + } + for _, f := range s.before { ctx = f(ctx, r) } @@ -116,6 +132,11 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // for their own error types. See the example shipping/handling service. type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) +// ServerFinalizerFunc can be used to perform work at the end of an HTTP +// request, after the response has been written to the client. The principal +// intended use is for request logging. +type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) + func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { switch e := err.(type) { case Error: @@ -131,3 +152,15 @@ func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { http.Error(w, err.Error(), http.StatusInternalServerError) } } + +type interceptingWriter struct { + http.ResponseWriter + code int +} + +// WriteHeader may not be explicitly called, so care must be taken to +// initialize w.code to its default value of http.StatusOK. +func (w *interceptingWriter) WriteHeader(code int) { + w.code = code + w.ResponseWriter.WriteHeader(code) +} diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 752f010d8..0fd0bb598 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -9,6 +9,7 @@ import ( "golang.org/x/net/context" + "github.com/go-kit/kit/endpoint" httptransport "github.com/go-kit/kit/transport/http" ) @@ -91,6 +92,36 @@ func TestServerHappyPath(t *testing.T) { } } +func TestServerFinalizer(t *testing.T) { + c := make(chan int) + handler := httptransport.NewServer( + context.Background(), + endpoint.Nop, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.WriteHeader(<-c) + return nil + }, + httptransport.ServerFinalizer(func(_ context.Context, code int, _ *http.Request) { + c <- code + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + want := http.StatusTeapot + c <- want // give status code to response encoder + have := <-c // take status code from finalizer + + if want != have { + t.Errorf("want %d, have %d", want, have) + } +} + func testServer(t *testing.T) (cancel, step func(), resp <-chan *http.Response) { var ( ctx, cancelfn = context.WithCancel(context.Background())