From bf95e2482b4e0b4afe256e71ff7617b3e9e84573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Fri, 29 Mar 2024 17:40:02 +0100 Subject: [PATCH 1/3] Decoding/Encoding for []uint16, []uint32 and []uint64 --- decoder.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++ decoder_test.go | 24 +++++++++++++++ encoder.go | 57 +++++++++++++++++++++++++++++++++++ encoder_test.go | 56 +++++++++++++++++++++++++++++++++++ generate.go | 13 +++++++- 5 files changed, 228 insertions(+), 1 deletion(-) diff --git a/decoder.go b/decoder.go index 29e45c6..2017ff9 100644 --- a/decoder.go +++ b/decoder.go @@ -1,6 +1,7 @@ package scale import ( + "encoding/binary" "errors" "fmt" "io" @@ -330,6 +331,84 @@ func DecodeByteArray(d *Decoder, value []byte) (int, error) { return d.read(value) } +func DecodeUint16Slice(d *Decoder) ([]uint16, int, error) { + return DecodeUint16SliceWithLimit(d, d.maxElements) +} + +func DecodeUint16SliceWithLimit(d *Decoder, limit uint32) ([]uint16, int, error) { + lth, total, err := DecodeLen(d, limit) + if err != nil { + return nil, 0, err + } + if lth == 0 { + return nil, total, nil + } + values := make([]uint16, lth) + + for i := uint32(0); i < lth; i++ { + n, err := d.read(d.scratch[:2]) + if err != nil { + return nil, 0, err + } + total += n + values[i] = binary.LittleEndian.Uint16(d.scratch[:2]) + } + + return values, total, nil +} + +func DecodeUint32Slice(d *Decoder) ([]uint32, int, error) { + return DecodeUint32SliceWithLimit(d, d.maxElements) +} + +func DecodeUint32SliceWithLimit(d *Decoder, limit uint32) ([]uint32, int, error) { + lth, total, err := DecodeLen(d, limit) + if err != nil { + return nil, 0, err + } + if lth == 0 { + return nil, total, nil + } + values := make([]uint32, lth) + + for i := uint32(0); i < lth; i++ { + n, err := d.read(d.scratch[:4]) + if err != nil { + return nil, 0, err + } + total += n + values[i] = binary.LittleEndian.Uint32(d.scratch[:4]) + } + + return values, total, nil +} + +func DecodeUint64Slice(d *Decoder) ([]uint64, int, error) { + return DecodeUint64SliceWithLimit(d, d.maxElements) +} + +func DecodeUint64SliceWithLimit(d *Decoder, limit uint32) ([]uint64, int, error) { + lth, total, err := DecodeLen(d, limit) + if err != nil { + return nil, 0, err + } + if lth == 0 { + return nil, total, nil + } + values := make([]uint64, lth) + + for i := uint32(0); i < lth; i++ { + n, err := d.read(d.scratch[:8]) + if err != nil { + return nil, 0, err + } + total += n + values[i] = binary.LittleEndian.Uint64(d.scratch[:8]) + } + + return values, total, nil +} + func DecodeString(d *Decoder) (string, int, error) { return DecodeStringWithLimit(d, d.maxElements) } diff --git a/decoder_test.go b/decoder_test.go index 1110d88..4f61711 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -39,6 +39,12 @@ func testEncode(tb testing.TB, value any) []byte { _, err = EncodeCompact64(enc, val) case []byte: _, err = EncodeByteSlice(enc, val) + case []uint16: + _, err = EncodeUint16Slice(enc, val) + case []uint32: + _, err = EncodeUint32Slice(enc, val) + case []uint64: + _, err = EncodeUint64Slice(enc, val) case string: _, err = EncodeString(enc, val) case []string: @@ -65,6 +71,12 @@ func expectEqual(tb testing.TB, value any, r io.Reader) { rst, _, err = DecodeCompact64(dec) case []byte: rst, _, err = DecodeByteSlice(dec) + case []uint16: + rst, _, err = DecodeUint16Slice(dec) + case []uint32: + rst, _, err = DecodeUint32Slice(dec) + case []uint64: + rst, _, err = DecodeUint64Slice(dec) case string: rst, _, err = DecodeString(dec) case []string: @@ -107,6 +119,18 @@ func TestReadFull(t *testing.T) { desc: "string slice", expect: []string{"qwe123", "dsa456"}, }, + { + desc: "uint16 slice", + expect: []uint16{0, 1, 2, math.MaxUint8, math.MaxUint16}, + }, + { + desc: "uint32 slice", + expect: []uint32{0, 1, 2, math.MaxUint8, math.MaxUint16, math.MaxUint32}, + }, + { + desc: "uint64 slice", + expect: []uint64{0, 1, 2, math.MaxUint32, math.MaxUint64}, + }, } { t.Run(tc.desc, func(t *testing.T) { t.Run("full", func(t *testing.T) { diff --git a/encoder.go b/encoder.go index 440ab29..05cae50 100644 --- a/encoder.go +++ b/encoder.go @@ -105,6 +105,63 @@ func EncodeByteArray(e *Encoder, value []byte) (int, error) { return e.w.Write(value) } +func EncodeUint16Slice(e *Encoder, value []uint16) (int, error) { + return EncodeUint16SliceWithLimit(e, value, e.maxElements) +} + +func EncodeUint16SliceWithLimit(e *Encoder, value []uint16, limit uint32) (int, error) { + total, err := EncodeLen(e, uint32(len(value)), limit) + if err != nil { + return 0, err + } + for _, v := range value { + scratch := e.scratch[:2] + binary.LittleEndian.PutUint16(scratch, v) + e.w.Write(scratch) + total += 2 + } + + return total, nil +} + +func EncodeUint32Slice(e *Encoder, value []uint32) (int, error) { + return EncodeUint32SliceWithLimit(e, value, e.maxElements) +} + +func EncodeUint32SliceWithLimit(e *Encoder, value []uint32, limit uint32) (int, error) { + total, err := EncodeLen(e, uint32(len(value)), limit) + if err != nil { + return 0, err + } + for _, v := range value { + scratch := e.scratch[:4] + binary.LittleEndian.PutUint32(scratch, v) + e.w.Write(scratch) + total += 4 + } + + return total, nil +} + +func EncodeUint64Slice(e *Encoder, value []uint64) (int, error) { + return EncodeUint64SliceWithLimit(e, value, e.maxElements) +} + +func EncodeUint64SliceWithLimit(e *Encoder, value []uint64, limit uint32) (int, error) { + total, err := EncodeLen(e, uint32(len(value)), limit) + if err != nil { + return 0, err + } + for _, v := range value { + scratch := e.scratch[:8] + binary.LittleEndian.PutUint64(scratch, v) + e.w.Write(scratch) + total += 8 + } + + return total, nil +} + func EncodeString(e *Encoder, value string) (int, error) { return EncodeStringWithLimit(e, value, e.maxElements) } diff --git a/encoder_test.go b/encoder_test.go index 4353d50..10c8fc1 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -2,6 +2,7 @@ package scale import ( "bytes" + "encoding/hex" "math" "testing" @@ -128,6 +129,32 @@ func uint64TestCases() []compactTestCase[uint64] { } } +func mustDecodeHex(hexStr string) []byte { + b, err := hex.DecodeString(hexStr) + if err != nil { + panic(err) + } + return b +} + +func uint16SliceTestCases() []compactTestCase[[]uint16] { + return []compactTestCase[[]uint16]{ + {[]uint16{4, 8, 15, 16, 23, 42}, mustDecodeHex("18040008000f00100017002a00")}, + } +} + +func uint32SliceTestCases() []compactTestCase[[]uint32] { + return []compactTestCase[[]uint32]{ + {[]uint32{4, 8, 15, 16, 23, 42}, mustDecodeHex("1804000000080000000f00000010000000170000002a000000")}, + } +} + +func uint64SliceTestCases() []compactTestCase[[]uint64] { + return []compactTestCase[[]uint64]{ + {[]uint64{4, 8, 42}, mustDecodeHex("0c040000000000000008000000000000002a00000000000000")}, + } +} + func encodeTest[T any](t *testing.T, value T, expect []byte) { buf := bytes.NewBuffer(nil) enc := NewEncoder(buf) @@ -141,6 +168,14 @@ func encodeTest[T any](t *testing.T, value T, expect []byte) { _, err = EncodeCompact32(enc, typed) case uint64: _, err = EncodeCompact64(enc, typed) + case []uint16: + _, err = EncodeUint16Slice(enc, typed) + case []uint32: + _, err = EncodeUint32Slice(enc, typed) + case []uint64: + _, err = EncodeUint64Slice(enc, typed) + default: + t.Fatal("unsupported type") } require.NoError(t, err) require.Equal(t, expect, buf.Bytes()) @@ -175,4 +210,25 @@ func TestEncodeCompactIntegers(t *testing.T) { }) } }) + t.Run("[]uint16", func(t *testing.T) { + for _, tc := range uint16SliceTestCases() { + t.Run("", func(t *testing.T) { + encodeTest(t, tc.value, tc.expect) + }) + } + }) + t.Run("[]uint32", func(t *testing.T) { + for _, tc := range uint32SliceTestCases() { + t.Run("", func(t *testing.T) { + encodeTest(t, tc.value, tc.expect) + }) + } + }) + t.Run("[]uint64", func(t *testing.T) { + for _, tc := range uint64SliceTestCases() { + t.Run("", func(t *testing.T) { + encodeTest(t, tc.value, tc.expect) + }) + } + }) } diff --git a/generate.go b/generate.go index 4d90290..437b792 100644 --- a/generate.go +++ b/generate.go @@ -17,7 +17,7 @@ const ( // nolint package {{ .Package }} - + import ( "github.com/spacemeshos/go-scale" {{ range $pkg, $short := .Imported }}"{{ $pkg }}" @@ -327,6 +327,17 @@ func getScaleType(parentType reflect.Type, field reflect.StructField) (scaleType if field.Type.Elem().Kind() == reflect.Uint8 { return scaleType{Name: "ByteSliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil } + if field.Type.Elem().Name() == "uint16" { + return scaleType{Name: "Uint16SliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil + } + // Note: `field.Type.Elem().Kind() == reflect.Uint32` catches things like + // type Foo uint32. + if field.Type.Elem().Name() == "uint32" { + return scaleType{Name: "Uint32SliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil + } + if field.Type.Elem().Name() == "uint64" { + return scaleType{Name: "Uint64SliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil + } return scaleType{Name: "StructSliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil case reflect.Array: if field.Type.Elem().Kind() == reflect.Uint8 { From cb3c9a5bd5588fdae5bd99113ca17b9de84bd5b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Sat, 30 Mar 2024 09:27:01 +0100 Subject: [PATCH 2/3] Review feedback. Encode primitive slices as compact --- decoder.go | 13 +++-- encoder.go | 27 ++++++----- encoder_test.go | 6 +-- generate.go | 11 ++--- generate_test.go | 122 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 28 deletions(-) create mode 100644 generate_test.go diff --git a/decoder.go b/decoder.go index 2017ff9..ecacdc2 100644 --- a/decoder.go +++ b/decoder.go @@ -1,7 +1,6 @@ package scale import ( - "encoding/binary" "errors" "fmt" "io" @@ -346,12 +345,12 @@ func DecodeUint16SliceWithLimit(d *Decoder, limit uint32) ([]uint16, int, error) values := make([]uint16, lth) for i := uint32(0); i < lth; i++ { - n, err := d.read(d.scratch[:2]) + v, n, err := DecodeCompact16(d) if err != nil { return nil, 0, err } total += n - values[i] = binary.LittleEndian.Uint16(d.scratch[:2]) + values[i] = v } return values, total, nil @@ -372,12 +371,12 @@ func DecodeUint32SliceWithLimit(d *Decoder, limit uint32) ([]uint32, int, error) values := make([]uint32, lth) for i := uint32(0); i < lth; i++ { - n, err := d.read(d.scratch[:4]) + v, n, err := DecodeCompact32(d) if err != nil { return nil, 0, err } total += n - values[i] = binary.LittleEndian.Uint32(d.scratch[:4]) + values[i] = v } return values, total, nil @@ -398,12 +397,12 @@ func DecodeUint64SliceWithLimit(d *Decoder, limit uint32) ([]uint64, int, error) values := make([]uint64, lth) for i := uint32(0); i < lth; i++ { - n, err := d.read(d.scratch[:8]) + v, n, err := DecodeCompact64(d) if err != nil { return nil, 0, err } total += n - values[i] = binary.LittleEndian.Uint64(d.scratch[:8]) + values[i] = v } return values, total, nil diff --git a/encoder.go b/encoder.go index 05cae50..f2cbcf6 100644 --- a/encoder.go +++ b/encoder.go @@ -115,10 +115,11 @@ func EncodeUint16SliceWithLimit(e *Encoder, value []uint16, limit uint32) (int, return 0, err } for _, v := range value { - scratch := e.scratch[:2] - binary.LittleEndian.PutUint16(scratch, v) - e.w.Write(scratch) - total += 2 + n, err := EncodeCompact16(e, v) + if err != nil { + return 0, err + } + total += n } return total, nil @@ -134,10 +135,11 @@ func EncodeUint32SliceWithLimit(e *Encoder, value []uint32, limit uint32) (int, return 0, err } for _, v := range value { - scratch := e.scratch[:4] - binary.LittleEndian.PutUint32(scratch, v) - e.w.Write(scratch) - total += 4 + n, err := EncodeCompact32(e, v) + if err != nil { + return 0, err + } + total += n } return total, nil @@ -153,10 +155,11 @@ func EncodeUint64SliceWithLimit(e *Encoder, value []uint64, limit uint32) (int, return 0, err } for _, v := range value { - scratch := e.scratch[:8] - binary.LittleEndian.PutUint64(scratch, v) - e.w.Write(scratch) - total += 8 + n, err := EncodeCompact64(e, v) + if err != nil { + return 0, err + } + total += n } return total, nil diff --git a/encoder_test.go b/encoder_test.go index 10c8fc1..3e5a6cd 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -139,19 +139,19 @@ func mustDecodeHex(hexStr string) []byte { func uint16SliceTestCases() []compactTestCase[[]uint16] { return []compactTestCase[[]uint16]{ - {[]uint16{4, 8, 15, 16, 23, 42}, mustDecodeHex("18040008000f00100017002a00")}, + {[]uint16{4, 15, 23, math.MaxUint16}, mustDecodeHex("10103c5cfeff0300")}, } } func uint32SliceTestCases() []compactTestCase[[]uint32] { return []compactTestCase[[]uint32]{ - {[]uint32{4, 8, 15, 16, 23, 42}, mustDecodeHex("1804000000080000000f00000010000000170000002a000000")}, + {[]uint32{4, 15, 23, math.MaxUint32}, mustDecodeHex("10103c5c03ffffffff")}, } } func uint64SliceTestCases() []compactTestCase[[]uint64] { return []compactTestCase[[]uint64]{ - {[]uint64{4, 8, 42}, mustDecodeHex("0c040000000000000008000000000000002a00000000000000")}, + {[]uint64{4, 15, 23, math.MaxUint64}, mustDecodeHex("10103c5c13ffffffffffffffff")}, } } diff --git a/generate.go b/generate.go index 437b792..2b6c091 100644 --- a/generate.go +++ b/generate.go @@ -278,6 +278,7 @@ func getDecodeModifier(parentType reflect.Type, field reflect.StructField) strin func getScaleType(parentType reflect.Type, field reflect.StructField) (scaleType, error) { decodeModifier := getDecodeModifier(parentType, field) + encodableType := reflect.TypeOf((*Encodable)(nil)).Elem() switch field.Type.Kind() { case reflect.Bool: @@ -324,18 +325,16 @@ func getScaleType(parentType reflect.Type, field reflect.StructField) (scaleType if maxElements == 0 { return scaleType{}, errors.New("slices must have max scale tag") } - if field.Type.Elem().Kind() == reflect.Uint8 { + if field.Type.Elem().Kind() == reflect.Uint8 && !field.Type.Elem().Implements(encodableType) { return scaleType{Name: "ByteSliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil } - if field.Type.Elem().Name() == "uint16" { + if field.Type.Elem().Kind() == reflect.Uint16 && !field.Type.Elem().Implements(encodableType) { return scaleType{Name: "Uint16SliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil } - // Note: `field.Type.Elem().Kind() == reflect.Uint32` catches things like - // type Foo uint32. - if field.Type.Elem().Name() == "uint32" { + if field.Type.Elem().Kind() == reflect.Uint32 && !field.Type.Elem().Implements(encodableType) { return scaleType{Name: "Uint32SliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil } - if field.Type.Elem().Name() == "uint64" { + if field.Type.Elem().Kind() == reflect.Uint64 && !field.Type.Elem().Implements(encodableType) { return scaleType{Name: "Uint64SliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil } return scaleType{Name: "StructSliceWithLimit", Args: fmt.Sprintf(", %d", maxElements)}, nil diff --git a/generate_test.go b/generate_test.go new file mode 100644 index 0000000..2d73930 --- /dev/null +++ b/generate_test.go @@ -0,0 +1,122 @@ +package scale + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +type newU16 uint16 + +func (newU16) EncodeScale(enc *Encoder) (int, error) { + panic("uninmplemented") +} + +type newU32 uint32 + +func (newU32) EncodeScale(enc *Encoder) (int, error) { + panic("uninmplemented") +} + +type newU64 uint64 + +func (newU64) EncodeScale(enc *Encoder) (int, error) { + panic("uninmplemented") +} + +func Test_getScaleType_Slices(t *testing.T) { + t.Run("[]uint16", func(t *testing.T) { + type Foo struct { + Slice []uint16 `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "Uint16SliceWithLimit", scaleT.Name) + }) + t.Run("[]newUint16 (implements Encodable)", func(t *testing.T) { + type Foo struct { + Slice []newU16 `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "StructSliceWithLimit", scaleT.Name) + }) + t.Run("[]newUint16 (doesn't implement Encodable)", func(t *testing.T) { + type newT uint16 + type Foo struct { + Slice []newT `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "Uint16SliceWithLimit", scaleT.Name) + }) + t.Run("[]uint32", func(t *testing.T) { + type Foo struct { + Slice []uint32 `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "Uint32SliceWithLimit", scaleT.Name) + }) + t.Run("[]newUint32 (implements Encodable)", func(t *testing.T) { + type Foo struct { + Slice []newU32 `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "StructSliceWithLimit", scaleT.Name) + }) + t.Run("[]newUint32 (doesn't implement Encodable)", func(t *testing.T) { + type newT uint32 + type Foo struct { + Slice []newT `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "Uint32SliceWithLimit", scaleT.Name) + }) + t.Run("[]uint64", func(t *testing.T) { + type Foo struct { + Slice []uint64 `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "Uint64SliceWithLimit", scaleT.Name) + }) + t.Run("[]newUint64 (implements Encodable)", func(t *testing.T) { + type Foo struct { + Slice []newU64 `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "StructSliceWithLimit", scaleT.Name) + }) + t.Run("[]newUint64 (doesn't implement Encodable)", func(t *testing.T) { + type newT uint64 + type Foo struct { + Slice []newT `scale:"max=2"` + } + + rtype := reflect.TypeOf(Foo{}) + scaleT, err := getScaleType(rtype, rtype.Field(0)) + require.NoError(t, err) + require.Equal(t, "Uint64SliceWithLimit", scaleT.Name) + }) +} From 52a46467b66998b3cbae2ec7aa1bc24f3632cc10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Sat, 30 Mar 2024 09:39:38 +0100 Subject: [PATCH 3/3] Update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ac5eb57..7eee456 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,10 @@ Object{} | concatenation of fields uint8 | compact u8 [TODO no need for compact u8] uint16 | compact u16 uint32 | compact u32 -uint32 | compact u64 +uint34 | compact u64 +[]uint16 | length prefixed (compact u32) followed by compact u16s +[]uint32 | length prefixed (compact u32) followed by compact u32s +[]uint64 | length prefixed (compact u32) followed by compact u64s [...]Object | array with objects. encoded by consecutively encoding every object []Object | slice with objects. prefixed with compact u32