Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode: fix reuse of slice for array tables #934

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions internal/tracker/seen.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {

// CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are
// consistent.
func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
// consistent. It returns true if it is the first time this node's key is seen.
// Useful to clear array tables on first use.
func (s *SeenTracker) CheckExpression(node *unstable.Node) (bool, error) {
if s.entries == nil {
s.reset()
}
Expand All @@ -166,7 +167,7 @@ func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
}
}

func (s *SeenTracker) checkTable(node *unstable.Node) error {
func (s *SeenTracker) checkTable(node *unstable.Node) (bool, error) {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
Expand All @@ -192,7 +193,7 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
} else {
entry := s.entries[idx]
if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
}
}
parentIdx = idx
Expand All @@ -201,25 +202,27 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
k := it.Node().Data
idx := s.find(parentIdx, k)

first := false
if idx >= 0 {
kind := s.entries[idx].kind
if kind != tableKind {
return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
return false, fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
}
if s.entries[idx].explicit {
return fmt.Errorf("toml: table %s already exists", string(k))
return false, fmt.Errorf("toml: table %s already exists", string(k))
}
s.entries[idx].explicit = true
} else {
idx = s.create(parentIdx, k, tableKind, true, false)
first = true
}

s.currentIdx = idx

return nil
return first, nil
}

func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
func (s *SeenTracker) checkArrayTable(node *unstable.Node) (bool, error) {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
Expand All @@ -242,7 +245,7 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
} else {
entry := s.entries[idx]
if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
}
}

Expand All @@ -252,22 +255,23 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
k := it.Node().Data
idx := s.find(parentIdx, k)

if idx >= 0 {
firstTime := idx < 0
if firstTime {
idx = s.create(parentIdx, k, arrayTableKind, true, false)
} else {
kind := s.entries[idx].kind
if kind != arrayTableKind {
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
return false, fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
}
s.clear(idx)
} else {
idx = s.create(parentIdx, k, arrayTableKind, true, false)
}

s.currentIdx = idx

return nil
return firstTime, nil
}

func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
parentIdx := s.currentIdx
it := node.Key()

Expand All @@ -281,11 +285,11 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
} else {
entry := s.entries[idx]
if it.IsLast() {
return fmt.Errorf("toml: key %s is already defined", string(k))
return false, fmt.Errorf("toml: key %s is already defined", string(k))
} else if entry.kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} else if entry.explicit {
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
}
}

Expand All @@ -303,30 +307,30 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
return s.checkArray(value)
}

return nil
return false, nil
}

func (s *SeenTracker) checkArray(node *unstable.Node) error {
func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) {
it := node.Children()
for it.Next() {
n := it.Node()
switch n.Kind {
case unstable.InlineTable:
err := s.checkInlineTable(n)
first, err = s.checkInlineTable(n)
if err != nil {
return err
return false, err
}
case unstable.Array:
err := s.checkArray(n)
first, err = s.checkArray(n)
if err != nil {
return err
return false, err
}
}
}
return nil
return first, nil
}

func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
func (s *SeenTracker) checkInlineTable(node *unstable.Node) (first bool, err error) {
if pool.New == nil {
pool.New = func() interface{} {
return &SeenTracker{}
Expand All @@ -339,9 +343,9 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
it := node.Children()
for it.Next() {
n := it.Node()
err := s.checkKeyValue(n)
first, err = s.checkKeyValue(n)
if err != nil {
return err
return false, err
}
}

Expand All @@ -352,5 +356,5 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
// redefinition of its keys: check* functions cannot walk into
// a value.
pool.Put(s)
return nil
return first, nil
}
18 changes: 16 additions & 2 deletions unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ type decoder struct {
// need to be skipped.
skipUntilTable bool

// Flag indicating that the current array/slice table should be cleared because
// it is the first encounter of an array table.
clearArrayTable bool

// Tracks position in Go arrays.
// This is used when decoding [[array tables]] into Go arrays. Given array
// tables are separate TOML expression, we need to keep track of where we
Expand Down Expand Up @@ -246,9 +250,10 @@ Rules for the unmarshal code:
func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
var x reflect.Value
var err error
var first bool // used for to clear array tables on first use

if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
err = d.seen.CheckExpression(expr)
first, err = d.seen.CheckExpression(expr)
if err != nil {
return err
}
Expand All @@ -267,6 +272,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err
case unstable.ArrayTable:
d.skipUntilTable = false
d.strict.EnterArrayTable(expr)
d.clearArrayTable = first
x, err = d.handleArrayTable(expr.Key(), v)
default:
panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
Expand Down Expand Up @@ -307,6 +313,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
reflect.Copy(nelem, elem)
elem = nelem
}
if d.clearArrayTable && elem.Len() > 0 {
elem.SetLen(0)
d.clearArrayTable = false
}
}
return d.handleArrayTableCollectionLast(key, elem)
case reflect.Ptr:
Expand All @@ -325,6 +335,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec

return v, nil
case reflect.Slice:
if d.clearArrayTable && v.Len() > 0 {
v.SetLen(0)
d.clearArrayTable = false
}
elemType := v.Type().Elem()
var elem reflect.Value
if elemType.Kind() == reflect.Interface {
Expand Down Expand Up @@ -576,7 +590,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
break
}

err := d.seen.CheckExpression(expr)
_, err := d.seen.CheckExpression(expr)
if err != nil {
return reflect.Value{}, err
}
Expand Down
70 changes: 70 additions & 0 deletions unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2823,6 +2823,76 @@ blah.a = "def"`)
require.Equal(t, "def", cfg.A)
}

func TestIssue931(t *testing.T) {
type item struct {
Name string
}

type items struct {
Slice []item
}

its := items{[]item{{"a"}, {"b"}}}

b := []byte(`
[[Slice]]
Name = 'c'

[[Slice]]
Name = 'd'
`)

toml.Unmarshal(b, &its)
require.Equal(t, items{[]item{{"c"}, {"d"}}}, its)
}

func TestIssue931Interface(t *testing.T) {
type items struct {
Slice interface{}
}

type item = map[string]interface{}

its := items{[]interface{}{item{"Name": "a"}, item{"Name": "b"}}}

b := []byte(`
[[Slice]]
Name = 'c'

[[Slice]]
Name = 'd'
`)

toml.Unmarshal(b, &its)
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
}

func TestIssue931SliceInterface(t *testing.T) {
type items struct {
Slice []interface{}
}

type item = map[string]interface{}

its := items{
[]interface{}{
item{"Name": "a"},
item{"Name": "b"},
},
}

b := []byte(`
[[Slice]]
Name = 'c'

[[Slice]]
Name = 'd'
`)

toml.Unmarshal(b, &its)
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
}

func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct {
desc string
Expand Down
Loading