diff --git a/src/errors/join.go b/src/errors/join.go index 349fc06ed9f75c..fc3f5925fdad94 100644 --- a/src/errors/join.go +++ b/src/errors/join.go @@ -17,24 +17,33 @@ import ( // // A non-nil error returned by Join implements the Unwrap() []error method. func Join(errs ...error) error { - n := 0 - for _, err := range errs { - if err != nil { - n++ + allErrs := make([]error, 0, len(errs)) + for _, e := range errs { + // Ignore nil errors. + if e == nil { + continue } + + // Specifically handle nested join errors from the standard library. This + // avoids deeply-nesting values which can be unexpected when unwrapping + // errors. + joinErr, ok := e.(*joinError) + if !ok { + allErrs = append(allErrs, e) + continue + } + + allErrs = append(allErrs, joinErr.errs...) } - if n == 0 { + + // Ensure we return nil if all contained errors were nil. + if len(allErrs) == 0 { return nil } - e := &joinError{ - errs: make([]error, 0, n), - } - for _, err := range errs { - if err != nil { - e.errs = append(e.errs, err) - } + + return &joinError{ + errs: allErrs, } - return e } type joinError struct { diff --git a/src/errors/join_test.go b/src/errors/join_test.go index 4828dc4d755fd6..8e580fab26542b 100644 --- a/src/errors/join_test.go +++ b/src/errors/join_test.go @@ -37,6 +37,9 @@ func TestJoin(t *testing.T) { }, { errs: []error{err1, nil, err2}, want: []error{err1, err2}, + }, { + errs: []error{errors.Join(err1, err2), err1, err2}, + want: []error{err1, err2, err1, err2}, }} { got := errors.Join(test.errs...).(interface{ Unwrap() []error }).Unwrap() if !reflect.DeepEqual(got, test.want) {