Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reflect Marshaler #1592

Merged
merged 15 commits into from
Oct 10, 2024
Merged
170 changes: 170 additions & 0 deletions abi/dynamic/reflect_marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright (C) 2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package dynamic

import (
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"

"github.com/ava-labs/avalanchego/utils/wrappers"
"golang.org/x/text/cases"
"golang.org/x/text/language"

"github.com/ava-labs/hypersdk/abi"
"github.com/ava-labs/hypersdk/codec"
"github.com/ava-labs/hypersdk/consts"
)

var ErrTypeNotFound = errors.New("type not found in ABI")

func Marshal(inputABI abi.ABI, typeName string, jsonData string) ([]byte, error) {
if _, ok := findABIType(inputABI, typeName); !ok {
return nil, fmt.Errorf("marshalling %s: %w", typeName, ErrTypeNotFound)
}

typeCache := make(map[string]reflect.Type)

typ, err := getReflectType(typeName, inputABI, typeCache)
if err != nil {
return nil, fmt.Errorf("failed to get reflect type: %w", err)
}

value := reflect.New(typ).Interface()

if err := json.Unmarshal([]byte(jsonData), value); err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON data: %w", err)
}

writer := codec.NewWriter(0, consts.NetworkSizeLimit)
if err := codec.LinearCodec.MarshalInto(value, writer.Packer); err != nil {
return nil, fmt.Errorf("failed to marshal struct: %w", err)
}

return writer.Bytes(), nil
}

func Unmarshal(inputABI abi.ABI, typeName string, data []byte) (string, error) {
if _, ok := findABIType(inputABI, typeName); !ok {
return "", fmt.Errorf("unmarshalling %s: %w", typeName, ErrTypeNotFound)
}

typeCache := make(map[string]reflect.Type)

typ, err := getReflectType(typeName, inputABI, typeCache)
if err != nil {
return "", fmt.Errorf("failed to get reflect type: %w", err)
}

value := reflect.New(typ).Interface()

packer := wrappers.Packer{
Bytes: data,
MaxSize: consts.NetworkSizeLimit,
}
if err := codec.LinearCodec.UnmarshalFrom(&packer, value); err != nil {
return "", fmt.Errorf("failed to unmarshal data: %w", err)
}

jsonData, err := json.Marshal(value)
if err != nil {
return "", fmt.Errorf("failed to marshal struct to JSON: %w", err)
}

return string(jsonData), nil
}

// Matches fixed-size arrays like [32]uint8
var fixedSizeArrayRegex = regexp.MustCompile(`^\[(\d+)\](.+)$`)

func getReflectType(abiTypeName string, inputABI abi.ABI, typeCache map[string]reflect.Type) (reflect.Type, error) {
switch abiTypeName {
case "string":
return reflect.TypeOf(""), nil
case "uint8":
return reflect.TypeOf(uint8(0)), nil
case "uint16":
return reflect.TypeOf(uint16(0)), nil
case "uint32":
return reflect.TypeOf(uint32(0)), nil
case "uint64":
return reflect.TypeOf(uint64(0)), nil
case "int8":
return reflect.TypeOf(int8(0)), nil
case "int16":
return reflect.TypeOf(int16(0)), nil
case "int32":
return reflect.TypeOf(int32(0)), nil
case "int64":
return reflect.TypeOf(int64(0)), nil
case "Address":
return reflect.TypeOf(codec.Address{}), nil
default:
// golang slices
if strings.HasPrefix(abiTypeName, "[]") {
elemType, err := getReflectType(strings.TrimPrefix(abiTypeName, "[]"), inputABI, typeCache)
if err != nil {
return nil, err
}
return reflect.SliceOf(elemType), nil
}

// golang arrays

if match := fixedSizeArrayRegex.FindStringSubmatch(abiTypeName); match != nil {
sizeStr := match[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable is unnecessary as it's only used on the next line and it doesn't add any additional context beyond what size itself does.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exclusively for readability

size, err := strconv.Atoi(sizeStr)
if err != nil {
return nil, fmt.Errorf("failed to convert size to int: %w", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using %w makes the error part of your public API, which I don't think is the desired behaviour here. %v is more appropriate IMO.

}
elemType, err := getReflectType(match[2], inputABI, typeCache)
if err != nil {
return nil, err
}
return reflect.ArrayOf(size, elemType), nil
}

// For custom types, recursively construct the struct type
if cachedType, ok := typeCache[abiTypeName]; ok {
containerman17 marked this conversation as resolved.
Show resolved Hide resolved
return cachedType, nil
}

abiType, ok := findABIType(inputABI, abiTypeName)
if !ok {
return nil, fmt.Errorf("type %s not found in ABI", abiTypeName)
}

// It is a struct, as we don't support anything else as custom types
fields := make([]reflect.StructField, len(abiType.Fields))
containerman17 marked this conversation as resolved.
Show resolved Hide resolved
for i, field := range abiType.Fields {
fieldType, err := getReflectType(field.Type, inputABI, typeCache)
if err != nil {
return nil, err
}
fields[i] = reflect.StructField{
Name: cases.Title(language.English).String(field.Name),
containerman17 marked this conversation as resolved.
Show resolved Hide resolved
Type: fieldType,
Tag: reflect.StructTag(fmt.Sprintf(`serialize:"true" json:"%s"`, field.Name)),
}
}

structType := reflect.StructOf(fields)
typeCache[abiTypeName] = structType

return structType, nil
}
}

func findABIType(inputABI abi.ABI, typeName string) (abi.Type, bool) {
for _, typ := range inputABI.Types {
if typ.Name == typeName {
return typ, true
}
}
return abi.Type{}, false
}
94 changes: 94 additions & 0 deletions abi/dynamic/reflect_marshal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (C) 2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package dynamic

import (
"encoding/hex"
"encoding/json"
"os"
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/ava-labs/hypersdk/abi"
)

func TestDynamicMarshal(t *testing.T) {
require := require.New(t)

abiJSON := mustReadFile(t, "../testdata/abi.json")
var abi abi.ABI

err := json.Unmarshal(abiJSON, &abi)
require.NoError(err)

testCases := []struct {
name string
typeName string
}{
{"empty", "MockObjectSingleNumber"},
{"uint16", "MockObjectSingleNumber"},
{"numbers", "MockObjectAllNumbers"},
{"arrays", "MockObjectArrays"},
{"transfer", "MockActionTransfer"},
{"transferField", "MockActionWithTransfer"},
{"transfersArray", "MockActionWithTransferArray"},
{"strBytes", "MockObjectStringAndBytes"},
{"strByteZero", "MockObjectStringAndBytes"},
{"strBytesEmpty", "MockObjectStringAndBytes"},
{"strOnly", "MockObjectStringAndBytes"},
{"outer", "Outer"},
{"fixedBytes", "FixedBytes"},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Read the JSON data
jsonData := mustReadFile(t, "../testdata/"+tc.name+".json")

objectBytes, err := Marshal(abi, tc.typeName, string(jsonData))
containerman17 marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(err)

// Compare with expected hex
expectedHex := string(mustReadFile(t, "../testdata/"+tc.name+".hex"))
expectedHex = strings.TrimSpace(expectedHex)
require.Equal(expectedHex, hex.EncodeToString(objectBytes))

unmarshaledJSON, err := Unmarshal(abi, tc.typeName, objectBytes)
require.NoError(err)

// Compare with expected JSON
require.JSONEq(string(jsonData), unmarshaledJSON)
})
}
}

func TestDynamicMarshalErrors(t *testing.T) {
require := require.New(t)

abiJSON := mustReadFile(t, "../testdata/abi.json")
var abi abi.ABI

err := json.Unmarshal(abiJSON, &abi)
require.NoError(err)

// Test malformed JSON
malformedJSON := `{"uint8": 42, "uint16": 1000, "uint32": 100000, "uint64": 10000000000, "int8": -42, "int16": -1000, "int32": -100000, "int64": -10000000000,`
_, err = Marshal(abi, "MockObjectAllNumbers", malformedJSON)
require.Contains(err.Error(), "unexpected end of JSON input")
aaronbuchwald marked this conversation as resolved.
Show resolved Hide resolved

// Test wrong struct name
jsonData := mustReadFile(t, "../testdata/numbers.json")
_, err = Marshal(abi, "NonExistentObject", string(jsonData))
require.ErrorIs(err, ErrTypeNotFound)
}

func mustReadFile(t *testing.T, path string) []byte {
t.Helper()

content, err := os.ReadFile(path)
require.NoError(t, err)
return content
}
2 changes: 1 addition & 1 deletion abi/testdata/transfer.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"to": "0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"to": "0x0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"value": 1000,
"memo": "aGk="
}
2 changes: 1 addition & 1 deletion abi/testdata/transferField.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"transfer": {
"to": "0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"to": "0x0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"value": 1000,
"memo": "aGk="
}
Expand Down
4 changes: 2 additions & 2 deletions abi/testdata/transfersArray.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
{
"transfers": [
{
"to": "0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"to": "0x0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"value": 1000,
"memo": "aGk="
},
{
"to": "0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"to": "0x0102030405060708090a0b0c0d0e0f101112131400000000000000000000000000",
"value": 1000,
"memo": "aGk="
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
golang.org/x/crypto v0.22.0
golang.org/x/exp v0.0.0-20231127185646-65229373498e
golang.org/x/sync v0.7.0
golang.org/x/text v0.14.0
google.golang.org/grpc v1.62.0
google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v2 v2.4.0
Expand Down Expand Up @@ -143,7 +144,6 @@ require (
golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/term v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.17.0 // indirect
gonum.org/v1/gonum v0.11.0 // indirect
Expand Down
Loading