From bb14b8b9de1133342fce9aff0158b354ec72f08c Mon Sep 17 00:00:00 2001 From: Mitar Date: Wed, 8 Nov 2023 12:04:17 -0800 Subject: [PATCH] Add additional hlog logging handlers (#607) * Add HTTPVersionHandler. * Add RemoteIPHandler. * Add trimPort to HostHandler. * Add EtagHandler. * Add ResponseHeaderHandler. * Add TestGetHost. * Call AccessHandler's f also on panic. --- hlog/hlog.go | 108 +++++++++++++++++++++++++++++++++++--- hlog/hlog_test.go | 129 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 8 deletions(-) diff --git a/hlog/hlog.go b/hlog/hlog.go index e0dd512a..06ca4adf 100644 --- a/hlog/hlog.go +++ b/hlog/hlog.go @@ -3,7 +3,9 @@ package hlog import ( "context" + "net" "net/http" + "strings" "time" "github.com/rs/xid" @@ -89,6 +91,35 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler { } } +func getHost(hostPort string) string { + if hostPort == "" { + return "" + } + + host, _, err := net.SplitHostPort(hostPort) + if err != nil { + return hostPort + } + return host +} + +// RemoteIPHandler is similar to RemoteAddrHandler, but logs only +// an IP, not a port. +func RemoteIPHandler(fieldKey string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getHost(r.RemoteAddr) + if ip != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, ip) + }) + } + next.ServeHTTP(w, r) + }) + } +} + // UserAgentHandler adds the request's user-agent as a field to the context's logger // using fieldKey as field key. func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler { @@ -135,6 +166,21 @@ func ProtoHandler(fieldKey string) func(next http.Handler) http.Handler { } } +// HTTPVersionHandler is similar to ProtoHandler, but it does not store the "HTTP/" +// prefix in the protocol name. +func HTTPVersionHandler(fieldKey string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proto := strings.TrimPrefix(r.Proto, "HTTP/") + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, proto) + }) + next.ServeHTTP(w, r) + }) + } +} + type idKey struct{} // IDFromRequest returns the unique id associated to the request if any. @@ -205,27 +251,75 @@ func CustomHeaderHandler(fieldKey, header string) func(next http.Handler) http.H } } +// EtagHandler adds Etag header from response's header as a field to +// the context's logger using fieldKey as field key. +func EtagHandler(fieldKey string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + etag := w.Header().Get("Etag") + if etag != "" { + etag = strings.ReplaceAll(etag, `"`, "") + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, etag) + }) + } + }() + next.ServeHTTP(w, r) + }) + } +} + +func ResponseHeaderHandler(fieldKey, headerName string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + value := w.Header().Get(headerName) + if value != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, value) + }) + } + }() + next.ServeHTTP(w, r) + }) + } +} + // AccessHandler returns a handler that call f after each request. func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lw := mutil.WrapWriter(w) + defer func() { + f(r, lw.Status(), lw.BytesWritten(), time.Since(start)) + }() next.ServeHTTP(lw, r) - f(r, lw.Status(), lw.BytesWritten(), time.Since(start)) }) } } // HostHandler adds the request's host as a field to the context's logger -// using fieldKey as field key. -func HostHandler(fieldKey string) func(next http.Handler) http.Handler { +// using fieldKey as field key. If trimPort is set to true, then port is +// removed from the host. +func HostHandler(fieldKey string, trimPort ...bool) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log := zerolog.Ctx(r.Context()) - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str(fieldKey, r.Host) - }) + var host string + if len(trimPort) > 0 && trimPort[0] { + host = getHost(r.Host) + } else { + host = r.Host + } + if host != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, host) + }) + } next.ServeHTTP(w, r) }) } diff --git a/hlog/hlog_test.go b/hlog/hlog_test.go index 1f5a1bcd..3dc7d317 100644 --- a/hlog/hlog_test.go +++ b/hlog/hlog_test.go @@ -122,6 +122,38 @@ func TestRemoteAddrHandlerIPv6(t *testing.T) { } } +func TestRemoteIPHandler(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + RemoteAddr: "1.2.3.4:1234", + } + h := RemoteIPHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestRemoteIPHandlerIPv6(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + RemoteAddr: "[2001:db8:a0b:12f0::1]:1234", + } + h := RemoteIPHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestUserAgentHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -201,6 +233,46 @@ func TestCustomHeaderHandler(t *testing.T) { } } +func TestEtagHandler(t *testing.T) { + out := &bytes.Buffer{} + w := httptest.NewRecorder() + r := &http.Request{} + h := EtagHandler("etag")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Etag", `"abcdef"`) + w.WriteHeader(http.StatusOK) + })) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + l := FromRequest(r) + l.Log().Msg("") + }) + h3 := NewHandler(zerolog.New(out))(h2) + h3.ServeHTTP(w, r) + if want, got := `{"etag":"abcdef"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestResponseHeaderHandler(t *testing.T) { + out := &bytes.Buffer{} + w := httptest.NewRecorder() + r := &http.Request{} + h := ResponseHeaderHandler("encoding", "Content-Encoding")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", `gzip`) + w.WriteHeader(http.StatusOK) + })) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + l := FromRequest(r) + l.Log().Msg("") + }) + h3 := NewHandler(zerolog.New(out))(h2) + h3.ServeHTTP(w, r) + if want, got := `{"encoding":"gzip"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestProtoHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -217,6 +289,22 @@ func TestProtoHandler(t *testing.T) { } } +func TestHTTPVersionHandler(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + Proto: "HTTP/1.1", + } + h := HTTPVersionHandler("proto")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"proto":"1.1"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestCombinedHandlers(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -295,14 +383,53 @@ func TestCtxWithID(t *testing.T) { func TestHostHandler(t *testing.T) { out := &bytes.Buffer{} - r := &http.Request{Host: "example.com"} + r := &http.Request{Host: "example.com:8080"} h := HostHandler("host")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"host":"example.com:8080"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestHostHandlerWithoutPort(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{Host: "example.com:8080"} + h := HostHandler("host", true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) if want, got := `{"host":"example.com"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } + +func TestGetHost(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"example.com:8080", "example.com"}, + {"example.com", "example.com"}, + {"invalid", "invalid"}, + {"192.168.0.1:8080", "192.168.0.1"}, + {"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, + {"こんにちは.com:8080", "こんにちは.com"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + result := getHost(tt.input) + if tt.expected != result { + t.Errorf("Invalid log output, got: %s, want: %s", result, tt.expected) + } + }) + } +}