Skip to content

Commit

Permalink
Merge pull request #112 from goccy/feature/support-strict-option
Browse files Browse the repository at this point in the history
Support DisallowDuplicateKey and Strict option for Decoder
  • Loading branch information
goccy authored Jun 1, 2020
2 parents b0455d0 + e95f2fb commit aefe2e5
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
47 changes: 46 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Decoder struct {
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
disallowDuplicateKey bool
useOrderedMap bool
parsedFile *ast.File
streamIndex int
Expand All @@ -52,6 +53,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
disallowDuplicateKey: false,
useOrderedMap: false,
}
}
Expand Down Expand Up @@ -347,6 +349,18 @@ func errUnknownField(msg string, tk *token.Token) *unknownFieldError {
return &unknownFieldError{err: errors.ErrSyntax(msg, tk)}
}

type duplicateKeyError struct {
err error
}

func (e *duplicateKeyError) Error() string {
return e.err.Error()
}

func errDuplicateKey(msg string, tk *token.Token) *duplicateKeyError {
return &duplicateKeyError{err: errors.ErrSyntax(msg, tk)}
}

func (d *Decoder) deleteStructKeys(structValue reflect.Value, unknownFields map[string]ast.Node) error {
strType := structValue.Type()
structFieldMap, err := structFieldMap(strType)
Expand Down Expand Up @@ -953,6 +967,20 @@ func (d *Decoder) decodeMapItem(dst *MapItem, src ast.Node) error {
return nil
}

func (d *Decoder) validateMapKey(keyMap map[string]struct{}, key interface{}, keyNode ast.Node) error {
k, ok := key.(string)
if !ok {
return nil
}
if d.disallowDuplicateKey {
if _, exists := keyMap[k]; exists {
return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken())
}
}
keyMap[k] = struct{}{}
return nil
}

func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error {
mapNode, err := d.getMapNode(src)
if err != nil {
Expand All @@ -963,6 +991,7 @@ func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error {
}
mapSlice := MapSlice{}
mapIter := mapNode.MapRange()
keyMap := map[string]struct{}{}
for mapIter.Next() {
key := mapIter.Key()
value := mapIter.Value()
Expand All @@ -972,12 +1001,19 @@ func (d *Decoder) decodeMapSlice(dst *MapSlice, src ast.Node) error {
return errors.Wrapf(err, "failed to decode map with merge key")
}
for _, v := range m {
if err := d.validateMapKey(keyMap, v.Key, value); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapSlice = append(mapSlice, v)
}
continue
}
k := d.nodeToValue(key)
if err := d.validateMapKey(keyMap, k, key); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapSlice = append(mapSlice, MapItem{
Key: d.nodeToValue(key),
Key: k,
Value: d.nodeToValue(value),
})
}
Expand All @@ -998,6 +1034,7 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
keyType := mapValue.Type().Key()
valueType := mapValue.Type().Elem()
mapIter := mapNode.MapRange()
keyMap := map[string]struct{}{}
var foundErr error
for mapIter.Next() {
key := mapIter.Key()
Expand All @@ -1008,6 +1045,9 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
}
iter := dst.MapRange()
for iter.Next() {
if err := d.validateMapKey(keyMap, iter.Key(), value); err != nil {
return errors.Wrapf(err, "invalid map key")
}
mapValue.SetMapIndex(iter.Key(), iter.Value())
}
continue
Expand All @@ -1016,6 +1056,11 @@ func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
if k.IsValid() && k.Type().ConvertibleTo(keyType) {
k = k.Convert(keyType)
}
if k.IsValid() {
if err := d.validateMapKey(keyMap, k.Interface(), key); err != nil {
return errors.Wrapf(err, "invalid map key")
}
}
if valueType.Kind() == reflect.Ptr && value.Type() == ast.NullType {
// set nil value to pointer
mapValue.SetMapIndex(k, reflect.Zero(valueType))
Expand Down
20 changes: 20 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,26 @@ children:
})
}

func TestDecoder_DisallowDuplicateKey(t *testing.T) {
yml := `
a: b
a: c
`
expected := `
[3:1] duplicate key "a"
2 |
> 3 | a: b
4 | a: c
^
`
var v map[string]string
err := yaml.NewDecoder(strings.NewReader(yml), yaml.DisallowDuplicateKey()).Decode(&v)
actual := "\n" + err.Error()
if expected != actual {
t.Fatalf("expected:[%s] actual:[%s]", expected, actual)
}
}

func TestDecoder_DefaultValues(t *testing.T) {
v := struct {
A string `yaml:"a"`
Expand Down
17 changes: 17 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ func Validator(v StructValidator) DecodeOption {
}
}

// Strict enable DisallowUnknownField and DisallowDuplicateKey
func Strict() DecodeOption {
return func(d *Decoder) error {
d.disallowUnknownField = true
d.disallowDuplicateKey = true
return nil
}
}

// DisallowUnknownField causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
Expand All @@ -59,6 +68,14 @@ func DisallowUnknownField() DecodeOption {
}
}

// DisallowDuplicateKey causes an error when mapping keys that are duplicates
func DisallowDuplicateKey() DecodeOption {
return func(d *Decoder) error {
d.disallowDuplicateKey = true
return nil
}
}

// UseOrderedMap can be interpreted as a map,
// and uses MapSlice ( ordered map ) aggressively if there is no type specification
func UseOrderedMap() DecodeOption {
Expand Down

0 comments on commit aefe2e5

Please sign in to comment.