Skip to content

Commit

Permalink
fix: check error before Set/Unset/Add() (#504)
Browse files Browse the repository at this point in the history
  • Loading branch information
AsterDY authored Aug 16, 2023
1 parent c7754d3 commit b894f41
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 73 deletions.
108 changes: 36 additions & 72 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,11 @@ import (
`fmt`
`strconv`
`unsafe`
`reflect`

`github.com/bytedance/sonic/internal/native/types`
`github.com/bytedance/sonic/internal/rt`
)

const (
_CAP_BITS = 32
_LEN_MASK = 1 << _CAP_BITS - 1

_NODE_SIZE = unsafe.Sizeof(Node{})
_PAIR_SIZE = unsafe.Sizeof(Pair{})
)

const (
_V_NONE types.ValueType = 0
_V_NODE_BASE types.ValueType = 1 << 5
Expand All @@ -61,10 +52,6 @@ const (
V_ANY = int(_V_ANY)
)

var (
byteType = rt.UnpackType(reflect.TypeOf(byte(0)))
)

type Node struct {
t types.ValueType
l uint
Expand Down Expand Up @@ -143,6 +130,9 @@ func (self *Node) isAny() bool {

// Raw returns json representation of the node,
func (self *Node) Raw() (string, error) {
if self == nil {
return "", ErrNotExist
}
if !self.IsRaw() {
buf, err := self.MarshalJSON()
return rt.Mem2Str(buf), err
Expand All @@ -157,7 +147,7 @@ func (self *Node) checkRaw() error {
if self.IsRaw() {
self.parseRaw(false)
}
return nil
return self.Check()
}

// Bool returns bool value represented by this node,
Expand Down Expand Up @@ -538,15 +528,18 @@ func (self *Node) Cap() (int, error) {
//
// If self is V_NONE or V_NULL, it becomes V_OBJECT and sets the node at the key.
func (self *Node) Set(key string, node Node) (bool, error) {
if self != nil && (self.t == _V_NONE || self.t == types.V_NULL) {
*self = NewObject([]Pair{{key, node}})
return false, nil
if err := self.Check(); err != nil {
return false, err
}

if err := node.Check(); err != nil {
return false, err
}

if self.t == _V_NONE || self.t == types.V_NULL {
*self = NewObject([]Pair{{key, node}})
return false, nil
}

p := self.Get(key)

if !p.Exists() {
Expand Down Expand Up @@ -575,7 +568,9 @@ func (self *Node) SetAny(key string, val interface{}) (bool, error) {
// Unset RESET the node of given key under object parent, and reports if the key has existed.
// WARN: After conducting `UnsetXX()`, the node's length WON'T change
func (self *Node) Unset(key string) (bool, error) {
self.must(types.V_OBJECT, "an object")
if err := self.should(types.V_OBJECT, "an object"); err != nil {
return false, err
}
p, i := self.skipKey(key)
if !p.Exists() {
return false, nil
Expand All @@ -591,10 +586,18 @@ func (self *Node) Unset(key string) (bool, error) {
//
// The index must be within self's children.
func (self *Node) SetByIndex(index int, node Node) (bool, error) {
if err := self.Check(); err != nil {
return false, err
}
if err := node.Check(); err != nil {
return false, err
}

if index == 0 && (self.t == _V_NONE || self.t == types.V_NULL) {
*self = NewArray([]Node{node})
return false, nil
}

p := self.Index(index)
if !p.Exists() {
return false, ErrNotExist
Expand All @@ -614,6 +617,10 @@ func (self *Node) SetAnyByIndex(index int, val interface{}) (bool, error) {
// UnsetByIndex remove the node of given index
// WARN: After conducting `UnsetXX()`, the node's length WON'T change
func (self *Node) UnsetByIndex(index int) (bool, error) {
if err := self.Check(); err != nil {
return false, err
}

var p *Node
it := self.itype()
if it == types.V_ARRAY {
Expand Down Expand Up @@ -647,6 +654,10 @@ func (self *Node) UnsetByIndex(index int) (bool, error) {
//
// If self is V_NONE or V_NULL, it becomes V_ARRAY and sets the node at index 0.
func (self *Node) Add(node Node) error {
if err := self.Check(); err != nil {
return err
}

if self != nil && (self.t == _V_NONE || self.t == types.V_NULL) {
*self = NewArray([]Node{node})
return nil
Expand Down Expand Up @@ -846,7 +857,7 @@ func (self *Node) unsafeMap() (*linkedPairs, error) {

// SortKeys sorts children of a V_OBJECT node in ascending key-order.
// If recurse is true, it recursively sorts children's children as long as a V_OBJECT node is found.
func (self *Node) SortKeys(recurse bool) (err error) {
func (self *Node) SortKeys(recurse bool) error {
ps, err := self.unsafeMap()
if err != nil {
return err
Expand All @@ -867,7 +878,9 @@ func (self *Node) SortKeys(recurse bool) (err error) {
}
return true
}
self.ForEach(sc)
if err := self.ForEach(sc); err != nil {
return err
}
}
return nil
}
Expand Down Expand Up @@ -1103,9 +1116,8 @@ func (self *Node) LoadAll() error {
// Load loads the node's children as parsed.
// After calling it, only the node itself can be used on concurrency (not include its children)
func (self *Node) Load() error {
if self.IsRaw() {
self.parseRaw(false)
return self.Load()
if err := self.checkRaw(); err != nil {
return err
}

switch self.t {
Expand All @@ -1120,40 +1132,6 @@ func (self *Node) Load() error {

/**---------------------------------- Internal Helper Methods ----------------------------------**/

var (
_NODE_TYPE = rt.UnpackEface(Node{}).Type
_PAIR_TYPE = rt.UnpackEface(Pair{}).Type
)

// func (self *Node) setCapAndLen(cap int, len int) {
// if self.t == types.V_ARRAY || self.t == types.V_OBJECT || self.t == _V_ARRAY_LAZY || self.t == _V_OBJECT_LAZY {
// self.l = uint32(len)
// self.c = uint32(cap)
// } else {
// panic("value does not have a length")
// }
// }

func (self *Node) unsafe_next() *Node {
return (*Node)(unsafe.Pointer(uintptr(unsafe.Pointer(self)) + _NODE_SIZE))
}

func (self *Pair) unsafe_next() *Pair {
return (*Pair)(unsafe.Pointer(uintptr(unsafe.Pointer(self)) + _PAIR_SIZE))
}

func (self *Node) must(t types.ValueType, s string) {
if err := self.checkRaw(); err != nil {
panic(err)
}
if err := self.Check(); err != nil {
panic(err)
}
if self.itype() != t {
panic("value cannot be represented as " + s)
}
}

func (self *Node) should(t types.ValueType, s string) error {
if err := self.checkRaw(); err != nil {
return err
Expand Down Expand Up @@ -1463,20 +1441,6 @@ func (self *Node) toGenericObjectUseNode() (map[string]Node, error) {
return out, nil
}

func (self *Node) toGenericObjectUsePair() ([]Pair, error) {
var nb = self.len()
if nb == 0 {
return []Pair{}, nil
}

var s = (*linkedPairs)(self.p)
var out = make([]Pair, nb)
s.ToSlice(out)

/* all done */
return out, nil
}

/**------------------------------------ Factory Methods ------------------------------------**/

var (
Expand Down
156 changes: 155 additions & 1 deletion ast/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,161 @@ func TestTypeCast(t *testing.T) {
}
}

func TestCheckError(t *testing.T) {
func TestCheckError_Nil(t *testing.T) {
nill := (*Node)(nil)
if nill.Valid() || nill.Check() == nil {
t.Fail()
}
if nill.Get("").Check() == nil {
t.Fatal()
}
if nill.GetByPath("").Check() == nil {
t.Fatal()
}
if nill.Index(1).Check() == nil {
t.Fatal()
}
if nill.IndexOrGet(1, "a").Check() == nil {
t.Fatal()
}
if nill.IndexPair(1) != nil {
t.Fatal()
}
if _, err := nill.Set("", Node{}); err == nil {
t.Fatal()
}
if _, err := nill.SetByIndex(1, Node{}); err == nil {
t.Fatal()
}
if _, err := nill.SetAny("", 1); err == nil {
t.Fatal()
}
if _, err := nill.SetAnyByIndex(1, 1); err == nil {
t.Fatal()
}
if _, err := nill.Unset(""); err == nil {
t.Fatal()
}
if _, err := nill.UnsetByIndex(1); err == nil {
t.Fatal()
}
if err := nill.Add(Node{}); err == nil {
t.Fatal()
}
if err := nill.AddAny(1); err == nil {
t.Fatal()
}
}

func TestCheckError_None(t *testing.T) {
nill := Node{}
if !nill.Valid() || nill.Check() != nil {
t.Fail()
}
if nill.Get("").Check() == nil {
t.Fatal()
}
if nill.GetByPath("").Check() == nil {
t.Fatal()
}
if nill.Index(1).Check() == nil {
t.Fatal()
}
if nill.IndexOrGet(1, "a").Check() == nil {
t.Fatal()
}
if nill.IndexPair(1) != nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.Set("a", Node{}); err != nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.SetByIndex(0, Node{}); err != nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.SetByIndex(1, Node{}); err == nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.SetAny("a", 1); err != nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.SetAnyByIndex(0, 1); err != nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.SetAnyByIndex(1, 1); err == nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.Unset(""); err == nil {
t.Fatal()
}
nill = Node{}
if _, err := nill.UnsetByIndex(1); err == nil {
t.Fatal()
}
nill = Node{}
if err := nill.Add(Node{}); err != nil {
t.Fatal()
}
nill = Node{}
if err := nill.AddAny(1); err != nil {
t.Fatal()
}
}

func TestCheckError_Error(t *testing.T) {
nill := newError(types.ERR_EOF, "")
if nill.Valid() || nill.Check() == nil {
t.Fail()
}
if nill.Get("").Check() == nil {
t.Fatal()
}
if nill.GetByPath("").Check() == nil {
t.Fatal()
}
if nill.Index(1).Check() == nil {
t.Fatal()
}
if nill.IndexOrGet(1, "a").Check() == nil {
t.Fatal()
}
if nill.IndexPair(1) != nil {
t.Fatal()
}
if _, err := nill.Set("", Node{}); err == nil {
t.Fatal()
}
if _, err := nill.SetByIndex(1, Node{}); err == nil {
t.Fatal()
}
if _, err := nill.SetAny("", 1); err == nil {
t.Fatal()
}
if _, err := nill.SetAnyByIndex(1, 1); err == nil {
t.Fatal()
}
if _, err := nill.Unset(""); err == nil {
t.Fatal()
}
if _, err := nill.UnsetByIndex(1); err == nil {
t.Fatal()
}
if err := nill.Add(Node{}); err == nil {
t.Fatal()
}
if err := nill.AddAny(1); err == nil {
t.Fatal()
}
}

func TestCheckError_Empty(t *testing.T) {
empty := Node{}
if !empty.Valid() || empty.Check() != nil || empty.Error() != "" {
t.Fatal()
Expand Down

0 comments on commit b894f41

Please sign in to comment.