Skip to content

Commit a5e9814

Browse files
authored
Merge pull request #35 from hashicorp/f-go113-errors
Support Go 1.13 errors.As/Is/Unwrap functionality
2 parents ece20dc + 8f55492 commit a5e9814

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

multierror.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package multierror
22

33
import (
4+
"errors"
45
"fmt"
56
)
67

@@ -49,3 +50,69 @@ func (e *Error) GoString() string {
4950
func (e *Error) WrappedErrors() []error {
5051
return e.Errors
5152
}
53+
54+
// Unwrap returns an error from Error (or nil if there are no errors).
55+
// This error returned will further support Unwrap to get the next error,
56+
// etc. The order will match the order of Errors in the multierror.Error
57+
// at the time of calling.
58+
//
59+
// The resulting error supports errors.As/Is/Unwrap so you can continue
60+
// to use the stdlib errors package to introspect further.
61+
//
62+
// This will perform a shallow copy of the errors slice. Any errors appended
63+
// to this error after calling Unwrap will not be available until a new
64+
// Unwrap is called on the multierror.Error.
65+
func (e *Error) Unwrap() error {
66+
// If we have no errors then we do nothing
67+
if e == nil || len(e.Errors) == 0 {
68+
return nil
69+
}
70+
71+
// If we have exactly one error, we can just return that directly.
72+
if len(e.Errors) == 1 {
73+
return e.Errors[0]
74+
}
75+
76+
// Shallow copy the slice
77+
errs := make([]error, len(e.Errors))
78+
copy(errs, e.Errors)
79+
return chain(errs)
80+
}
81+
82+
// chain implements the interfaces necessary for errors.Is/As/Unwrap to
83+
// work in a deterministic way with multierror. A chain tracks a list of
84+
// errors while accounting for the current represented error. This lets
85+
// Is/As be meaningful.
86+
//
87+
// Unwrap returns the next error. In the cleanest form, Unwrap would return
88+
// the wrapped error here but we can't do that if we want to properly
89+
// get access to all the errors. Instead, users are recommended to use
90+
// Is/As to get the correct error type out.
91+
//
92+
// Precondition: []error is non-empty (len > 0)
93+
type chain []error
94+
95+
// Error implements the error interface
96+
func (e chain) Error() string {
97+
return e[0].Error()
98+
}
99+
100+
// Unwrap implements errors.Unwrap by returning the next error in the
101+
// chain or nil if there are no more errors.
102+
func (e chain) Unwrap() error {
103+
if len(e) == 1 {
104+
return nil
105+
}
106+
107+
return e[1:]
108+
}
109+
110+
// As implements errors.As by attempting to map to the current value.
111+
func (e chain) As(target interface{}) bool {
112+
return errors.As(e[0], target)
113+
}
114+
115+
// Is implements errors.Is by comparing the current value directly.
116+
func (e chain) Is(target error) bool {
117+
return errors.Is(e[0], target)
118+
}

multierror_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package multierror
22

33
import (
44
"errors"
5+
"fmt"
56
"reflect"
67
"testing"
78
)
@@ -69,3 +70,134 @@ func TestErrorWrappedErrors(t *testing.T) {
6970
t.Fatalf("bad: %s", multi.WrappedErrors())
7071
}
7172
}
73+
74+
func TestErrorUnwrap(t *testing.T) {
75+
t.Run("with errors", func(t *testing.T) {
76+
err := &Error{Errors: []error{
77+
errors.New("foo"),
78+
errors.New("bar"),
79+
errors.New("baz"),
80+
}}
81+
82+
var current error = err
83+
for i := 0; i < len(err.Errors); i++ {
84+
current = errors.Unwrap(current)
85+
if !errors.Is(current, err.Errors[i]) {
86+
t.Fatal("should be next value")
87+
}
88+
}
89+
90+
if errors.Unwrap(current) != nil {
91+
t.Fatal("should be nil at the end")
92+
}
93+
})
94+
95+
t.Run("with no errors", func(t *testing.T) {
96+
err := &Error{Errors: nil}
97+
if errors.Unwrap(err) != nil {
98+
t.Fatal("should be nil")
99+
}
100+
})
101+
102+
t.Run("with nil multierror", func(t *testing.T) {
103+
var err *Error
104+
if errors.Unwrap(err) != nil {
105+
t.Fatal("should be nil")
106+
}
107+
})
108+
}
109+
110+
func TestErrorIs(t *testing.T) {
111+
errBar := errors.New("bar")
112+
113+
t.Run("with errBar", func(t *testing.T) {
114+
err := &Error{Errors: []error{
115+
errors.New("foo"),
116+
errBar,
117+
errors.New("baz"),
118+
}}
119+
120+
if !errors.Is(err, errBar) {
121+
t.Fatal("should be true")
122+
}
123+
})
124+
125+
t.Run("with errBar wrapped by fmt.Errorf", func(t *testing.T) {
126+
err := &Error{Errors: []error{
127+
errors.New("foo"),
128+
fmt.Errorf("errorf: %w", errBar),
129+
errors.New("baz"),
130+
}}
131+
132+
if !errors.Is(err, errBar) {
133+
t.Fatal("should be true")
134+
}
135+
})
136+
137+
t.Run("without errBar", func(t *testing.T) {
138+
err := &Error{Errors: []error{
139+
errors.New("foo"),
140+
errors.New("baz"),
141+
}}
142+
143+
if errors.Is(err, errBar) {
144+
t.Fatal("should be false")
145+
}
146+
})
147+
}
148+
149+
func TestErrorAs(t *testing.T) {
150+
match := &nestedError{}
151+
152+
t.Run("with the value", func(t *testing.T) {
153+
err := &Error{Errors: []error{
154+
errors.New("foo"),
155+
match,
156+
errors.New("baz"),
157+
}}
158+
159+
var target *nestedError
160+
if !errors.As(err, &target) {
161+
t.Fatal("should be true")
162+
}
163+
if target == nil {
164+
t.Fatal("target should not be nil")
165+
}
166+
})
167+
168+
t.Run("with the value wrapped by fmt.Errorf", func(t *testing.T) {
169+
err := &Error{Errors: []error{
170+
errors.New("foo"),
171+
fmt.Errorf("errorf: %w", match),
172+
errors.New("baz"),
173+
}}
174+
175+
var target *nestedError
176+
if !errors.As(err, &target) {
177+
t.Fatal("should be true")
178+
}
179+
if target == nil {
180+
t.Fatal("target should not be nil")
181+
}
182+
})
183+
184+
t.Run("without the value", func(t *testing.T) {
185+
err := &Error{Errors: []error{
186+
errors.New("foo"),
187+
errors.New("baz"),
188+
}}
189+
190+
var target *nestedError
191+
if errors.As(err, &target) {
192+
t.Fatal("should be false")
193+
}
194+
if target != nil {
195+
t.Fatal("target should be nil")
196+
}
197+
})
198+
}
199+
200+
// nestedError implements error and is used for tests.
201+
type nestedError struct{}
202+
203+
func (*nestedError) Error() string { return "" }

0 commit comments

Comments
 (0)