From 74e8fa99388a3613a5cdc2272ad3c48b2c5cccdb Mon Sep 17 00:00:00 2001 From: Billy Keyes Date: Mon, 14 Mar 2022 12:24:24 -0700 Subject: [PATCH] Generalize stack formatting in errfmt.Print (#87) Instead of special-casing hatpear.PanicError in the zerolog error formatter, detect the interface it implements and format the stack trace directly. This allows other packages to implement the same interface and get the same formatting benefits. --- baseapp/error.go | 9 +-- pkg/errfmt/errfmt.go | 42 ++++++++++--- pkg/errfmt/errfmt_test.go | 122 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 16 deletions(-) create mode 100644 pkg/errfmt/errfmt_test.go diff --git a/baseapp/error.go b/baseapp/error.go index 879419cc..2920053e 100644 --- a/baseapp/error.go +++ b/baseapp/error.go @@ -16,10 +16,8 @@ package baseapp import ( "context" - "fmt" "net/http" - "github.com/bluekeyes/hatpear" "github.com/palantir/go-baseapp/pkg/errfmt" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -35,12 +33,7 @@ type httpError interface { // RichErrorMarshalFunc is a zerolog error marshaller that formats the error as // a string that includes a stack trace, if one is available. func RichErrorMarshalFunc(err error) interface{} { - switch err := err.(type) { - case hatpear.PanicError: - return fmt.Sprintf("%+v", err) - default: - return errfmt.Print(err) - } + return errfmt.Print(err) } // HandleRouteError is a hatpear error handler that logs the error and sends diff --git a/pkg/errfmt/errfmt.go b/pkg/errfmt/errfmt.go index 7d4efa6a..1a6eab5e 100644 --- a/pkg/errfmt/errfmt.go +++ b/pkg/errfmt/errfmt.go @@ -12,10 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package errfmt implements formatting for error types. Specifically, it prints +// error messages with the deepest available stacktrace for errors that include +// stacktraces. package errfmt import ( "fmt" + "runtime" + "strings" "github.com/pkg/errors" ) @@ -24,21 +29,29 @@ type causer interface { Cause() error } -type stackTracer interface { +type pkgErrorsStackTracer interface { StackTrace() errors.StackTrace } +type runtimeStackTracer interface { + StackTrace() []runtime.Frame +} + +// Print returns a string representation of err. It returns the empty string if +// err is nil. func Print(err error) string { if err == nil { return "" } - var deepestStack stackTracer + var deepestStack interface{} currErr := err for currErr != nil { - if st, ok := currErr.(stackTracer); ok { - deepestStack = st + switch currErr.(type) { + case pkgErrorsStackTracer, runtimeStackTracer: + deepestStack = currErr } + cause, ok := currErr.(causer) if !ok { break @@ -46,9 +59,22 @@ func Print(err error) string { currErr = cause.Cause() } - if deepestStack == nil { - return err.Error() - } + return err.Error() + fmtStack(deepestStack) +} - return fmt.Sprintf("%s%+v", err.Error(), deepestStack.StackTrace()) +func fmtStack(tracer interface{}) string { + switch t := tracer.(type) { + case pkgErrorsStackTracer: + return fmt.Sprintf("%+v", t.StackTrace()) + case runtimeStackTracer: + var s strings.Builder + for _, frame := range t.StackTrace() { + s.WriteByte('\n') + _, _ = fmt.Fprintf(&s, "%s\n\t", frame.Function) + _, _ = fmt.Fprintf(&s, "%s:%d", frame.File, frame.Line) + } + return s.String() + default: + return "" + } } diff --git a/pkg/errfmt/errfmt_test.go b/pkg/errfmt/errfmt_test.go new file mode 100644 index 00000000..0362c266 --- /dev/null +++ b/pkg/errfmt/errfmt_test.go @@ -0,0 +1,122 @@ +// Copyright 2022 Palantir Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package errfmt + +import ( + "errors" + "runtime" + "strings" + "testing" + + pkgerrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrint(t *testing.T) { + t.Run("nilError", func(t *testing.T) { + assert.Empty(t, Print(nil), "nil error did not product empty output") + }) + + t.Run("plainError", func(t *testing.T) { + err := errors.New("this is an error") + assert.Equal(t, "this is an error", Print(err), "incorrect error output") + }) + + t.Run("nestedError", func(t *testing.T) { + root := errors.New("this is an error") + err := pkgerrors.WithMessage(root, "context 1") + err = pkgerrors.WithMessage(err, "context 2") + err = pkgerrors.WithMessage(err, "context 3") + assert.Equal(t, "context 3: context 2: context 1: this is an error", Print(err), "incorrect error output") + }) + + t.Run("pkgErrorsStackTrace", func(t *testing.T) { + const depth = 3 + const minLines = 1 + 2*(depth+1) + + err := recursiveError( + depth, + func() error { return errors.New("this is an error") }, + func(err error) error { return pkgerrors.Wrap(err, "context") }, + ) + + out := Print(err) + t.Log(out) + + outLines := strings.Split(out, "\n") + require.True(t, len(outLines) > minLines, "expected at least %d error lines, but got %d", minLines, len(outLines)) + + assert.Equal(t, "context: context: context: this is an error", outLines[0], "incorrect error message") + assert.Contains(t, outLines[3], "errfmt.recursiveError", "incorrect stack trace") + assert.Contains(t, outLines[5], "errfmt.recursiveError", "incorrect stack trace") + assert.Contains(t, outLines[7], "errfmt.recursiveError", "incorrect stack trace") + }) + + t.Run("runtimeStackTrace", func(t *testing.T) { + const depth = 3 + const minLines = 1 + 2*(depth+1) + + err := recursiveError( + depth, + func() error { return newStackTraceError("this is an error") }, + func(err error) error { return err }, + ) + + out := Print(err) + t.Log(out) + + outLines := strings.Split(out, "\n") + require.True(t, len(outLines) > minLines, "expected at least %d error lines, but got %d", minLines, len(outLines)) + + assert.Equal(t, "this is an error", outLines[0], "incorrect error message") + assert.Contains(t, outLines[3], "errfmt.recursiveError", "incorrect stack trace") + assert.Contains(t, outLines[5], "errfmt.recursiveError", "incorrect stack trace") + assert.Contains(t, outLines[7], "errfmt.recursiveError", "incorrect stack trace") + }) +} + +func recursiveError(depth int, root func() error, wrap func(error) error) error { + if depth == 0 { + return root() + } + return wrap(recursiveError(depth-1, root, wrap)) +} + +type stackTraceError struct { + msg string + st []runtime.Frame +} + +func newStackTraceError(msg string) error { + callers := make([]uintptr, 32) + + n := runtime.Callers(2, callers) + frames := runtime.CallersFrames(callers[0:n]) + + var stack []runtime.Frame + for { + f, more := frames.Next() + if !more { + break + } + stack = append(stack, f) + } + + return stackTraceError{msg: msg, st: stack} +} + +func (ste stackTraceError) Error() string { return ste.msg } +func (ste stackTraceError) StackTrace() []runtime.Frame { return ste.st }