Skip to content

Commit

Permalink
Add limit on nesting depth for go-scale (#60)
Browse files Browse the repository at this point in the history
* Add limit on nesting depth for go-scale

* Fix linter complaints

* Review feedback
  • Loading branch information
fasmat authored Jun 20, 2023
1 parent d22d124 commit 62aa1b3
Show file tree
Hide file tree
Showing 12 changed files with 649 additions and 71 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Not implemented:
## Code generation

```bash
go install ./scalegen
go install github.com/spacemeshos/go-scale/scalegen
```

`//go:generate scalegen` will discover all struct types and derive EncodeScale/DecodeScale methods. To avoid structs autodiscovery use `-types=U8,U16`.
`//go:generate scalegen` will discover all struct types and derive EncodeScale/DecodeScale methods. To avoid struct auto-discovery use `-types=U8,U16`.
80 changes: 67 additions & 13 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ const (
maxUint30 = 1<<30 - 1
)

// ErrDecodeTooManyElements is returned when scale limit tag is used and collection has too many elements to decode.
var ErrDecodeTooManyElements = errors.New("too many elements to decode in collection with scale limit set")
var (
// ErrDecodeTooManyElements is returned when scale limit tag is used and collection has too many elements to decode.
ErrDecodeTooManyElements = errors.New("too many elements to decode in collection with scale limit set")

// ErrDecodeNestedTooDeep is returned when nested level is too deep.
ErrDecodeNestedTooDeep = errors.New("nested level is too deep")
)

type Decodable interface {
DecodeScale(*Decoder) (int, error)
Expand All @@ -26,17 +31,54 @@ type DecodablePtr[B any] interface {
*B
}

func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
type decoderOpts func(*Decoder)

// WithDecodeMaxNested sets the nested level of the decoder.
// A value of 0 means no nesting is allowed. The default value is 4.
func WithDecodeMaxNested(nested uint) decoderOpts {
return func(d *Decoder) {
d.maxNested = nested
}
}

// WithDecodeMaxElements sets the maximum number of elements allowed in a collection.
// The default value is 1 << 20.
func WithDecodeMaxElements(elements uint32) decoderOpts {
return func(e *Decoder) {
e.maxElements = elements
}
}

func (d *Decoder) Reset(r io.Reader) {
d.r = r
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader, opts ...decoderOpts) *Decoder {
d := &Decoder{
r: r,
maxNested: MaxNested,
maxElements: MaxElements,
}
for _, opt := range opts {
opt(d)
}
return d
}

type Decoder struct {
r io.Reader
scratch [9]byte
r io.Reader
scratch [9]byte
maxNested uint
maxElements uint32
}

func (d *Decoder) enterNested() error {
if d.maxNested == 0 {
return ErrDecodeNestedTooDeep
}
d.maxNested--
return nil
}

func (e *Decoder) leaveNested() {
e.maxNested++
}

func (d *Decoder) read(buf []byte) (int, error) {
Expand Down Expand Up @@ -265,7 +307,7 @@ func DecodeStruct[V any, H DecodablePtr[V]](d *Decoder) (V, int, error) {
}

func DecodeByteSlice(d *Decoder) ([]byte, int, error) {
return DecodeByteSliceWithLimit(d, MaxElements)
return DecodeByteSliceWithLimit(d, d.maxElements)
}

func DecodeByteSliceWithLimit(d *Decoder, limit uint32) ([]byte, int, error) {
Expand All @@ -289,7 +331,7 @@ func DecodeByteArray(d *Decoder, value []byte) (int, error) {
}

func DecodeString(d *Decoder) (string, int, error) {
return DecodeStringWithLimit(d, MaxElements)
return DecodeStringWithLimit(d, d.maxElements)
}

func DecodeStringWithLimit(d *Decoder, limit uint32) (string, int, error) {
Expand All @@ -298,10 +340,14 @@ func DecodeStringWithLimit(d *Decoder, limit uint32) (string, int, error) {
}

func DecodeStructSlice[V any, H DecodablePtr[V]](d *Decoder) ([]V, int, error) {
return DecodeStructSliceWithLimit[V, H](d, MaxElements)
return DecodeStructSliceWithLimit[V, H](d, d.maxElements)
}

func DecodeStructSliceWithLimit[V any, H DecodablePtr[V]](d *Decoder, limit uint32) ([]V, int, error) {
if err := d.enterNested(); err != nil {
return nil, 0, err
}
defer d.leaveNested()
lth, total, err := DecodeLen(d, limit)
if err != nil {
return nil, total, err
Expand All @@ -323,7 +369,7 @@ func DecodeStructSliceWithLimit[V any, H DecodablePtr[V]](d *Decoder, limit uint
}

func DecodeSliceOfByteSlice(d *Decoder) ([][]byte, int, error) {
return DecodeSliceOfByteSliceWithLimit(d, MaxElements)
return DecodeSliceOfByteSliceWithLimit(d, d.maxElements)
}

func DecodeSliceOfByteSliceWithLimit(d *Decoder, limit uint32) ([][]byte, int, error) {
Expand All @@ -349,7 +395,7 @@ func DecodeSliceOfByteSliceWithLimit(d *Decoder, limit uint32) ([][]byte, int, e
}

func DecodeStringSlice(d *Decoder) ([]string, int, error) {
return DecodeStringSliceWithLimit(d, MaxElements)
return DecodeStringSliceWithLimit(d, d.maxElements)
}

func DecodeStringSliceWithLimit(d *Decoder, limit uint32) ([]string, int, error) {
Expand All @@ -369,6 +415,10 @@ func DecodeStringSliceWithLimit(d *Decoder, limit uint32) ([]string, int, error)
}

func DecodeStructArray[V any, H DecodablePtr[V]](d *Decoder, value []V) (int, error) {
if err := d.enterNested(); err != nil {
return 0, err
}
defer d.leaveNested()
total := 0
for i := range value {
n, err := H(&value[i]).DecodeScale(d)
Expand All @@ -381,6 +431,10 @@ func DecodeStructArray[V any, H DecodablePtr[V]](d *Decoder, value []V) (int, er
}

func DecodeOption[V any, H DecodablePtr[V]](d *Decoder) (*V, int, error) {
if err := d.enterNested(); err != nil {
return nil, 0, err
}
defer d.leaveNested()
exists, total, err := DecodeBool(d)
if !exists || err != nil {
return nil, total, err
Expand Down
131 changes: 91 additions & 40 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,21 @@ import (
"math/bits"
)

var MaxElements uint32 = 1 << 20
const (
// MaxElements is the maximum number of elements allowed in a collection if not set explicitly during encoding/decoding.
MaxElements uint32 = 1 << 20

// ErrEncodeTooManyElements is returned when scale limit tag is used and collection has too many elements to encode.
var ErrEncodeTooManyElements = errors.New("too many elements to encode in collection with scale limit set")
// MaxNested is the maximum nested level allowed if not set explicitly during encoding/decoding.
MaxNested uint = 4
)

var (
// ErrEncodeTooManyElements is returned when scale limit tag is used and collection has too many elements to encode.
ErrEncodeTooManyElements = errors.New("too many elements to encode in collection with scale limit set")

// ErrEncodeNestedTooDeep is returned when the depth of nested types exceeds the limit.
ErrEncodeNestedTooDeep = errors.New("nested level is too deep")
)

type Encodable interface {
EncodeScale(*Encoder) (int, error)
Expand All @@ -22,19 +33,59 @@ type EncodablePtr[B any] interface {
*B
}

type encoderOpts func(*Encoder)

// WithEncodeMaxNested sets the nested level of the encoder.
// A value of 0 means no nesting is allowed. The default value is 4.
func WithEncodeMaxNested(nested uint) encoderOpts {
return func(e *Encoder) {
e.maxNested = nested
}
}

// WithEncodeMaxElements sets the maximum number of elements allowed in a collection.
// The default value is 1 << 20.
func WithEncodeMaxElements(elements uint32) encoderOpts {
return func(e *Encoder) {
e.maxElements = elements
}
}

// NewEncoder returns a new encoder that writes to w.
// If w implements io.StringWriter, the returned encoder will be more efficient in encoding strings.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
func NewEncoder(w io.Writer, opts ...encoderOpts) *Encoder {
e := &Encoder{
w: w,
maxNested: MaxNested,
maxElements: MaxElements,
}
for _, opt := range opts {
opt(e)
}
return e
}

type Encoder struct {
w io.Writer
scratch [9]byte
w io.Writer
scratch [9]byte
maxNested uint
maxElements uint32
}

func (e *Encoder) enterNested() error {
if e.maxNested == 0 {
return ErrEncodeNestedTooDeep
}
e.maxNested--
return nil
}

func (e *Encoder) leaveNested() {
e.maxNested++
}

func EncodeByteSlice(e *Encoder, value []byte) (int, error) {
return EncodeByteSliceWithLimit(e, value, MaxElements)
return EncodeByteSliceWithLimit(e, value, e.maxElements)
}

func EncodeByteSliceWithLimit(e *Encoder, value []byte, limit uint32) (int, error) {
Expand All @@ -54,30 +105,30 @@ func EncodeByteArray(e *Encoder, value []byte) (int, error) {
}

func EncodeString(e *Encoder, value string) (int, error) {
return EncodeStringWithLimit(e, value, MaxElements)
return EncodeStringWithLimit(e, value, e.maxElements)
}

func EncodeStringWithLimit(e *Encoder, value string, limit uint32) (int, error) {
if sw, ok := e.w.(io.StringWriter); ok {
total, err := EncodeLen(e, uint32(len(value)), limit)
if err != nil {
return 0, err
}
n, err := sw.WriteString(value)
if err != nil {
return 0, err
}
return total + n, nil
total, err := EncodeLen(e, uint32(len(value)), limit)
if err != nil {
return 0, err
}

return EncodeByteSliceWithLimit(e, stringToBytes(value), limit)
n, err := io.WriteString(e.w, value)
if err != nil {
return 0, err
}
return total + n, nil
}

func EncodeStructSlice[V any, H EncodablePtr[V]](e *Encoder, value []V) (int, error) {
return EncodeStructSliceWithLimit[V, H](e, value, MaxElements)
return EncodeStructSliceWithLimit[V, H](e, value, e.maxElements)
}

func EncodeStructSliceWithLimit[V any, H EncodablePtr[V]](e *Encoder, value []V, limit uint32) (int, error) {
if err := e.enterNested(); err != nil {
return 0, err
}
defer e.leaveNested()
total, err := EncodeLen(e, uint32(len(value)), limit)
if err != nil {
return 0, err
Expand All @@ -92,17 +143,21 @@ func EncodeStructSliceWithLimit[V any, H EncodablePtr[V]](e *Encoder, value []V,
return total, nil
}

func EncodeSliceOfByteSlice(e *Encoder, value [][]byte) (int, error) {
return EncodeSliceOfByteSliceWithLimit(e, value, MaxElements)
func EncodeStringSlice(e *Encoder, value []string) (int, error) {
return EncodeStringSliceWithLimit(e, value, e.maxElements)
}

func EncodeSliceOfByteSliceWithLimit(e *Encoder, value [][]byte, limit uint32) (int, error) {
total, err := EncodeLen(e, uint32(len(value)), limit)
func EncodeStringSliceWithLimit(e *Encoder, value []string, limit uint32) (int, error) {
valueToBytes := make([][]byte, 0, len(value))
for i := range value {
valueToBytes = append(valueToBytes, stringToBytes(value[i]))
}
total, err := EncodeLen(e, uint32(len(valueToBytes)), limit)
if err != nil {
return 0, fmt.Errorf("EncodeLen failed: %w", err)
}
for _, byteSlice := range value {
n, err := EncodeByteSliceWithLimit(e, byteSlice, MaxElements)
for _, byteSlice := range valueToBytes {
n, err := EncodeByteSliceWithLimit(e, byteSlice, e.maxElements)
if err != nil {
return 0, fmt.Errorf("EncodeByteSliceWithLimit failed: %w", err)
}
Expand All @@ -111,19 +166,11 @@ func EncodeSliceOfByteSliceWithLimit(e *Encoder, value [][]byte, limit uint32) (
return total, nil
}

func EncodeStringSlice(e *Encoder, value []string) (int, error) {
return EncodeStringSliceWithLimit(e, value, MaxElements)
}

func EncodeStringSliceWithLimit(e *Encoder, value []string, limit uint32) (int, error) {
valueToBytes := make([][]byte, 0, len(value))
for i := range value {
valueToBytes = append(valueToBytes, stringToBytes(value[i]))
}
return EncodeSliceOfByteSliceWithLimit(e, valueToBytes, limit)
}

func EncodeStructArray[V any, H EncodablePtr[V]](e *Encoder, value []V) (int, error) {
if err := e.enterNested(); err != nil {
return 0, err
}
defer e.leaveNested()
total := 0
for i := range value {
n, err := H(&value[i]).EncodeScale(e)
Expand Down Expand Up @@ -239,6 +286,10 @@ func EncodeLen(e *Encoder, v uint32, limit uint32) (int, error) {
}

func EncodeOption[V any, H EncodablePtr[V]](e *Encoder, value *V) (int, error) {
if err := e.enterNested(); err != nil {
return 0, err
}
defer e.leaveNested()
if value == nil {
return EncodeBool(e, false)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/nested/nested.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package nested
//go:generate scalegen

type NestedModule struct {
Value []byte
Value []byte `scale:"max=32"`
}

type Struct struct {
Expand Down
4 changes: 2 additions & 2 deletions examples/nested/nested_scale.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 62aa1b3

Please sign in to comment.