Skip to content

Commit

Permalink
Merge pull request #25 from go-viper/error
Browse files Browse the repository at this point in the history
Replace internal joined error with errors.Join
  • Loading branch information
sagikazarmark authored Jun 2, 2024
2 parents 57a3d74 + 9e25c61 commit 35d054a
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 105 deletions.
50 changes: 0 additions & 50 deletions error.go

This file was deleted.

11 changes: 11 additions & 0 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package errors

import "errors"

func New(text string) error {
return errors.New(text)
}

func As(err error, target interface{}) bool {
return errors.As(err, target)
}
9 changes: 9 additions & 0 deletions internal/errors/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build go1.20

package errors

import "errors"

func Join(errs ...error) error {
return errors.Join(errs...)
}
61 changes: 61 additions & 0 deletions internal/errors/join_go1_19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//go:build !go1.20

// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package errors

// Join returns an error that wraps the given errors.
// Any nil error values are discarded.
// Join returns nil if every value in errs is nil.
// The error formats as the concatenation of the strings obtained
// by calling the Error method of each element of errs, with a newline
// between each string.
//
// A non-nil error returned by Join implements the Unwrap() []error method.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}

type joinError struct {
errs []error
}

func (e *joinError) Error() string {
// Since Join returns nil if every value in errs is nil,
// e.errs cannot be empty.
if len(e.errs) == 1 {
return e.errs[0].Error()
}

b := []byte(e.errs[0].Error())
for _, err := range e.errs[1:] {
b = append(b, '\n')
b = append(b, err.Error()...)
}
// At this point, b has at least one byte '\n'.
// return unsafe.String(&b[0], len(b))
return string(b)
}

func (e *joinError) Unwrap() []error {
return e.errs
}
66 changes: 30 additions & 36 deletions mapstructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,13 @@ package mapstructure

import (
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"

"github.com/go-viper/mapstructure/v2/internal/errors"
)

// DecodeHookFunc is the callback function that can be used for
Expand Down Expand Up @@ -414,7 +415,15 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) {
// Decode decodes the given raw interface to the target pointer specified
// by the configuration.
func (d *Decoder) Decode(input interface{}) error {
return d.decode("", input, reflect.ValueOf(d.config.Result).Elem())
err := d.decode("", input, reflect.ValueOf(d.config.Result).Elem())

// Retain some of the original behavior when multiple errors ocurr
var joinedErr interface{ Unwrap() []error }
if errors.As(err, &joinedErr) {
return fmt.Errorf("decoding failed due to the following error(s):\n\n%w", err)
}

return err
}

// Decodes an unknown data type into a specific reflection value.
Expand Down Expand Up @@ -881,7 +890,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
valElemType := valType.Elem()

// Accumulate errors
errors := make([]string, 0)
var errs []error

// If the input data is empty, then we just match what the input data is.
if dataVal.Len() == 0 {
Expand All @@ -903,15 +912,15 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
// First decode the key into the proper type
currentKey := reflect.Indirect(reflect.New(valKeyType))
if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
continue
}

// Next decode the data into the proper type
v := dataVal.MapIndex(k).Interface()
currentVal := reflect.Indirect(reflect.New(valElemType))
if err := d.decode(fieldName, v, currentVal); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
continue
}

Expand All @@ -921,12 +930,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
// Set the built up map to the value
val.Set(valMap)

// If we had errors, return those
if len(errors) > 0 {
return &joinedError{errors}
}

return nil
return errors.Join(errs...)
}

func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
Expand Down Expand Up @@ -1164,7 +1168,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
}

// Accumulate any errors
errors := make([]string, 0)
var errs []error

for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
Expand All @@ -1175,19 +1179,14 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)

fieldName := name + "[" + strconv.Itoa(i) + "]"
if err := d.decode(fieldName, currentData, currentField); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}
}

// Finally, set the value to the slice we built up
val.Set(valSlice)

// If there were errors, we return those
if len(errors) > 0 {
return &joinedError{errors}
}

return nil
return errors.Join(errs...)
}

func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error {
Expand Down Expand Up @@ -1233,27 +1232,22 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value)
}

// Accumulate any errors
errors := make([]string, 0)
var errs []error

for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
currentField := valArray.Index(i)

fieldName := name + "[" + strconv.Itoa(i) + "]"
if err := d.decode(fieldName, currentData, currentField); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}
}

// Finally, set the value to the array we built up
val.Set(valArray)

// If there were errors, we return those
if len(errors) > 0 {
return &joinedError{errors}
}

return nil
return errors.Join(errs...)
}

func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error {
Expand Down Expand Up @@ -1315,7 +1309,8 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}

targetValKeysUnused := make(map[interface{}]struct{})
errors := make([]string, 0)

var errs []error

// This slice will keep track of all the structs we'll be decoding.
// There can be more than one struct if there are embedded structs
Expand Down Expand Up @@ -1369,8 +1364,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e

if squash {
if fieldVal.Kind() != reflect.Struct {
errors = appendErrors(errors,
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
errs = append(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
} else {
structs = append(structs, fieldVal)
}
Expand Down Expand Up @@ -1449,7 +1443,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}

if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}
}

Expand All @@ -1464,7 +1458,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e

// Decode it as-if we were just decoding this map onto our map.
if err := d.decodeMap(name, remain, remainField.val); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}

// Set the map to nil so we have none so that the next check will
Expand All @@ -1480,7 +1474,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
sort.Strings(keys)

err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", "))
errors = appendErrors(errors, err)
errs = append(errs, err)
}

if d.config.ErrorUnset && len(targetValKeysUnused) > 0 {
Expand All @@ -1491,11 +1485,11 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
sort.Strings(keys)

err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", "))
errors = appendErrors(errors, err)
errs = append(errs, err)
}

if len(errors) > 0 {
return &joinedError{errors}
if err := errors.Join(errs...); err != nil {
return err
}

// Add the unused keys to the list of unused keys if we're tracking metadata
Expand Down
12 changes: 6 additions & 6 deletions mapstructure_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ func ExampleDecode_errors() {

fmt.Println(err.Error())
// Output:
// 5 error(s) decoding:
// decoding failed due to the following error(s):
//
// * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
// * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
// * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
// * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
// * 'Name' expected type 'string', got unconvertible type 'int', value: '123'
// 'Name' expected type 'string', got unconvertible type 'int', value: '123'
// 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
// 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
// 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
// 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
}

func ExampleDecode_metadata() {
Expand Down
Loading

0 comments on commit 35d054a

Please sign in to comment.