diff --git a/error.go b/error.go deleted file mode 100644 index 65624fc..0000000 --- a/error.go +++ /dev/null @@ -1,50 +0,0 @@ -package mapstructure - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// joinedError implements the error interface and can represents multiple -// errors that occur in the course of a single decode. -type joinedError struct { - Errors []string -} - -func (e *joinedError) Error() string { - points := make([]string, len(e.Errors)) - for i, err := range e.Errors { - points[i] = fmt.Sprintf("* %s", err) - } - - sort.Strings(points) - return fmt.Sprintf( - "%d error(s) decoding:\n\n%s", - len(e.Errors), strings.Join(points, "\n")) -} - -// Unwrap implements the Unwrap function added in Go 1.20. -func (e *joinedError) Unwrap() []error { - if e == nil { - return nil - } - - result := make([]error, len(e.Errors)) - for i, e := range e.Errors { - result[i] = errors.New(e) - } - - return result -} - -// TODO: replace with errors.Join when Go 1.20 is minimum version. -func appendErrors(errors []string, err error) []string { - switch e := err.(type) { - case *joinedError: - return append(errors, e.Errors...) - default: - return append(errors, e.Error()) - } -} diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..d1c15e4 --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,11 @@ +package errors + +import "errors" + +func New(text string) error { + return errors.New(text) +} + +func As(err error, target interface{}) bool { + return errors.As(err, target) +} diff --git a/internal/errors/join.go b/internal/errors/join.go new file mode 100644 index 0000000..d74e3a0 --- /dev/null +++ b/internal/errors/join.go @@ -0,0 +1,9 @@ +//go:build go1.20 + +package errors + +import "errors" + +func Join(errs ...error) error { + return errors.Join(errs...) +} diff --git a/internal/errors/join_go1_19.go b/internal/errors/join_go1_19.go new file mode 100644 index 0000000..700b402 --- /dev/null +++ b/internal/errors/join_go1_19.go @@ -0,0 +1,61 @@ +//go:build !go1.20 + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errors + +// Join returns an error that wraps the given errors. +// Any nil error values are discarded. +// Join returns nil if every value in errs is nil. +// The error formats as the concatenation of the strings obtained +// by calling the Error method of each element of errs, with a newline +// between each string. +// +// A non-nil error returned by Join implements the Unwrap() []error method. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + // Since Join returns nil if every value in errs is nil, + // e.errs cannot be empty. + if len(e.errs) == 1 { + return e.errs[0].Error() + } + + b := []byte(e.errs[0].Error()) + for _, err := range e.errs[1:] { + b = append(b, '\n') + b = append(b, err.Error()...) + } + // At this point, b has at least one byte '\n'. + // return unsafe.String(&b[0], len(b)) + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/mapstructure.go b/mapstructure.go index 34ae3a4..4b54fae 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -160,12 +160,13 @@ package mapstructure import ( "encoding/json" - "errors" "fmt" "reflect" "sort" "strconv" "strings" + + "github.com/go-viper/mapstructure/v2/internal/errors" ) // DecodeHookFunc is the callback function that can be used for @@ -414,7 +415,15 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) { // Decode decodes the given raw interface to the target pointer specified // by the configuration. func (d *Decoder) Decode(input interface{}) error { - return d.decode("", input, reflect.ValueOf(d.config.Result).Elem()) + err := d.decode("", input, reflect.ValueOf(d.config.Result).Elem()) + + // Retain some of the original behavior when multiple errors ocurr + var joinedErr interface{ Unwrap() []error } + if errors.As(err, &joinedErr) { + return fmt.Errorf("decoding failed due to the following error(s):\n\n%w", err) + } + + return err } // Decodes an unknown data type into a specific reflection value. @@ -881,7 +890,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle valElemType := valType.Elem() // Accumulate errors - errors := make([]string, 0) + var errs []error // If the input data is empty, then we just match what the input data is. if dataVal.Len() == 0 { @@ -903,7 +912,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle // First decode the key into the proper type currentKey := reflect.Indirect(reflect.New(valKeyType)) if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { - errors = appendErrors(errors, err) + errs = append(errs, err) continue } @@ -911,7 +920,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle v := dataVal.MapIndex(k).Interface() currentVal := reflect.Indirect(reflect.New(valElemType)) if err := d.decode(fieldName, v, currentVal); err != nil { - errors = appendErrors(errors, err) + errs = append(errs, err) continue } @@ -921,12 +930,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle // Set the built up map to the value val.Set(valMap) - // If we had errors, return those - if len(errors) > 0 { - return &joinedError{errors} - } - - return nil + return errors.Join(errs...) } func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { @@ -1164,7 +1168,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) } // Accumulate any errors - errors := make([]string, 0) + var errs []error for i := 0; i < dataVal.Len(); i++ { currentData := dataVal.Index(i).Interface() @@ -1175,19 +1179,14 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) fieldName := name + "[" + strconv.Itoa(i) + "]" if err := d.decode(fieldName, currentData, currentField); err != nil { - errors = appendErrors(errors, err) + errs = append(errs, err) } } // Finally, set the value to the slice we built up val.Set(valSlice) - // If there were errors, we return those - if len(errors) > 0 { - return &joinedError{errors} - } - - return nil + return errors.Join(errs...) } func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error { @@ -1233,7 +1232,7 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) } // Accumulate any errors - errors := make([]string, 0) + var errs []error for i := 0; i < dataVal.Len(); i++ { currentData := dataVal.Index(i).Interface() @@ -1241,19 +1240,14 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) fieldName := name + "[" + strconv.Itoa(i) + "]" if err := d.decode(fieldName, currentData, currentField); err != nil { - errors = appendErrors(errors, err) + errs = append(errs, err) } } // Finally, set the value to the array we built up val.Set(valArray) - // If there were errors, we return those - if len(errors) > 0 { - return &joinedError{errors} - } - - return nil + return errors.Join(errs...) } func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error { @@ -1315,7 +1309,8 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } targetValKeysUnused := make(map[interface{}]struct{}) - errors := make([]string, 0) + + var errs []error // This slice will keep track of all the structs we'll be decoding. // There can be more than one struct if there are embedded structs @@ -1369,8 +1364,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e if squash { if fieldVal.Kind() != reflect.Struct { - errors = appendErrors(errors, - fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) + errs = append(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) } else { structs = append(structs, fieldVal) } @@ -1449,7 +1443,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { - errors = appendErrors(errors, err) + errs = append(errs, err) } } @@ -1464,7 +1458,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e // Decode it as-if we were just decoding this map onto our map. if err := d.decodeMap(name, remain, remainField.val); err != nil { - errors = appendErrors(errors, err) + errs = append(errs, err) } // Set the map to nil so we have none so that the next check will @@ -1480,7 +1474,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e sort.Strings(keys) err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", ")) - errors = appendErrors(errors, err) + errs = append(errs, err) } if d.config.ErrorUnset && len(targetValKeysUnused) > 0 { @@ -1491,11 +1485,11 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e sort.Strings(keys) err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", ")) - errors = appendErrors(errors, err) + errs = append(errs, err) } - if len(errors) > 0 { - return &joinedError{errors} + if err := errors.Join(errs...); err != nil { + return err } // Add the unused keys to the list of unused keys if we're tracking metadata diff --git a/mapstructure_examples_test.go b/mapstructure_examples_test.go index 31cff15..a8735d4 100644 --- a/mapstructure_examples_test.go +++ b/mapstructure_examples_test.go @@ -63,13 +63,13 @@ func ExampleDecode_errors() { fmt.Println(err.Error()) // Output: - // 5 error(s) decoding: + // decoding failed due to the following error(s): // - // * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value' - // * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1' - // * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2' - // * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3' - // * 'Name' expected type 'string', got unconvertible type 'int', value: '123' + // 'Name' expected type 'string', got unconvertible type 'int', value: '123' + // 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value' + // 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1' + // 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2' + // 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3' } func ExampleDecode_metadata() { diff --git a/mapstructure_test.go b/mapstructure_test.go index c33d426..4501592 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -2,6 +2,7 @@ package mapstructure import ( "encoding/json" + "errors" "io" "reflect" "sort" @@ -2323,13 +2324,17 @@ func TestInvalidType(t *testing.T) { t.Fatal("error should exist") } - derr, ok := err.(*joinedError) - if !ok { - t.Fatalf("error should be kind of joinedError, instead: %#v", err) + var derr interface { + Unwrap() []error + } + + if !errors.As(err, &derr) { + t.Fatalf("error should be a type implementing Unwrap() []error, instead: %#v", err) } - if derr.Errors[0] != - "'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" { + errs := derr.Unwrap() + + if errs[0].Error() != "'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" { t.Errorf("got unexpected error: %s", err) } @@ -2342,12 +2347,13 @@ func TestInvalidType(t *testing.T) { t.Fatal("error should exist") } - derr, ok = err.(*joinedError) - if !ok { - t.Fatalf("error should be kind of joinedError, instead: %#v", err) + if !errors.As(err, &derr) { + t.Fatalf("error should be a type implementing Unwrap() []error, instead: %#v", err) } - if derr.Errors[0] != "cannot parse 'Vuint', -42 overflows uint" { + errs = derr.Unwrap() + + if errs[0].Error() != "cannot parse 'Vuint', -42 overflows uint" { t.Errorf("got unexpected error: %s", err) } @@ -2360,12 +2366,13 @@ func TestInvalidType(t *testing.T) { t.Fatal("error should exist") } - derr, ok = err.(*joinedError) - if !ok { - t.Fatalf("error should be kind of joinedError, instead: %#v", err) + if !errors.As(err, &derr) { + t.Fatalf("error should be a type implementing Unwrap() []error, instead: %#v", err) } - if derr.Errors[0] != "cannot parse 'Vuint', -42.000000 overflows uint" { + errs = derr.Unwrap() + + if errs[0].Error() != "cannot parse 'Vuint', -42.000000 overflows uint" { t.Errorf("got unexpected error: %s", err) } }