From 6e7fe1ab0ed3a24026ac74fa92c29e14002d9bc0 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Tue, 8 Oct 2024 16:54:10 -0500 Subject: [PATCH 01/19] initiated memo package --- go.mod | 1 + pkg/math/bits.go | 52 +++++ pkg/math/bits_test.go | 171 ++++++++++++++++ pkg/memo/arg.go | 52 +++++ pkg/memo/arg_test.go | 70 +++++++ pkg/memo/codec.go | 64 ++++++ pkg/memo/codec_abi.go | 92 +++++++++ pkg/memo/codec_abi_test.go | 227 +++++++++++++++++++++ pkg/memo/codec_compact.go | 259 ++++++++++++++++++++++++ pkg/memo/codec_compact_test.go | 350 +++++++++++++++++++++++++++++++++ pkg/memo/codec_test.go | 92 +++++++++ pkg/memo/fields_v0.go | 133 +++++++++++++ pkg/memo/memo.go | 169 ++++++++++++++++ pkg/memo/memo_test.go | 10 + testutil/sample/memo.go | 159 +++++++++++++++ 15 files changed, 1901 insertions(+) create mode 100644 pkg/math/bits.go create mode 100644 pkg/math/bits_test.go create mode 100644 pkg/memo/arg.go create mode 100644 pkg/memo/arg_test.go create mode 100644 pkg/memo/codec.go create mode 100644 pkg/memo/codec_abi.go create mode 100644 pkg/memo/codec_abi_test.go create mode 100644 pkg/memo/codec_compact.go create mode 100644 pkg/memo/codec_compact_test.go create mode 100644 pkg/memo/codec_test.go create mode 100644 pkg/memo/fields_v0.go create mode 100644 pkg/memo/memo.go create mode 100644 pkg/memo/memo_test.go create mode 100644 testutil/sample/memo.go diff --git a/go.mod b/go.mod index 7625fc2047..ba1bea9b91 100644 --- a/go.mod +++ b/go.mod @@ -336,6 +336,7 @@ require ( require ( github.com/bnb-chain/tss-lib v1.5.0 github.com/showa-93/go-mask v0.6.2 + github.com/test-go/testify v1.1.4 github.com/tonkeeper/tongo v1.9.3 ) diff --git a/pkg/math/bits.go b/pkg/math/bits.go new file mode 100644 index 0000000000..b00e1539b0 --- /dev/null +++ b/pkg/math/bits.go @@ -0,0 +1,52 @@ +package math + +import ( + "math/bits" +) + +// SetBit sets the bit at the given position (0-7) in the byte to 1 +func SetBit(b *byte, position uint8) { + if position > 7 { + return + } + *b |= 1 << position +} + +// IsBitSet returns true if the bit at the given position (0-7) is set in the byte, false otherwise +func IsBitSet(b byte, position uint8) bool { + if position > 7 { + return false + } + bitMask := byte(1 << position) + return b&bitMask != 0 +} + +// GetBits extracts the value of bits for a given mask +// +// Example: given b = 0b11011001 and mask = 0b11100000, the function returns 0b110 +func GetBits(b byte, mask byte) byte { + extracted := b & mask + + // get the number of trailing zero bits + trailingZeros := bits.TrailingZeros8(mask) + + // remove trailing zeros + return extracted >> trailingZeros +} + +// SetBits sets the value to the bits specified in the mask +// +// Example: given b = 0b00100001 and mask = 0b11100000, and value = 0b110, the function returns 0b11000001 +func SetBits(b byte, mask byte, value byte) byte { + // get the number of trailing zero bits in the mask + trailingZeros := bits.TrailingZeros8(mask) + + // shift the value left by the number of trailing zeros + valueShifted := value << trailingZeros + + // clear the bits in 'b' that correspond to the mask + bCleared := b &^ mask + + // Set the bits by ORing the cleared 'b' with the shifted value + return bCleared | valueShifted +} diff --git a/pkg/math/bits_test.go b/pkg/math/bits_test.go new file mode 100644 index 0000000000..445555a9a2 --- /dev/null +++ b/pkg/math/bits_test.go @@ -0,0 +1,171 @@ +package math_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/math" +) + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + initial byte + position uint8 + expected byte + }{ + { + name: "set bit at position 0", + initial: 0b00001000, + position: 0, + expected: 0b00001001, + }, + { + name: "set bit at position 7", + initial: 0b00001000, + position: 7, + expected: 0b10001000, + }, + { + name: "out of range bit position (no effect)", + initial: 0b00000000, + position: 8, // Out of range + expected: 0b00000000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := tt.initial + math.SetBit(&b, tt.position) + require.Equal(t, tt.expected, b) + }) + } +} + +func TestIsBitSet(t *testing.T) { + tests := []struct { + name string + b byte + position uint8 + expected bool + }{ + { + name: "bit 0 set", + b: 0b00000001, + position: 0, + expected: true, + }, + { + name: "bit 7 set", + b: 0b10000000, + position: 7, + expected: true, + }, + { + name: "bit 2 not set", + b: 0b00000001, + position: 2, + expected: false, + }, + { + name: "bit out of range", + b: 0b00000001, + position: 8, // Position out of range + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := math.IsBitSet(tt.b, tt.position) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestGetBits(t *testing.T) { + tests := []struct { + name string + b byte + mask byte + expected byte + }{ + { + name: "extract upper 3 bits", + b: 0b11011001, + mask: 0b11100000, + expected: 0b110, + }, + { + name: "extract middle 3 bits", + b: 0b11011001, + mask: 0b00011100, + expected: 0b110, + }, + { + name: "extract lower 3 bits", + b: 0b11011001, + mask: 0b00000111, + expected: 0b001, + }, + { + name: "extract no bits", + b: 0b11011001, + mask: 0b00000000, + expected: 0b000, + }, + { + name: "extract all bits", + b: 0b11111111, + mask: 0b11111111, + expected: 0b11111111, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := math.GetBits(tt.b, tt.mask) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestSetBits(t *testing.T) { + tests := []struct { + name string + b byte + mask byte + value byte + expected byte + }{ + { + name: "set upper 3 bits", + b: 0b00100001, + mask: 0b11100000, + value: 0b110, + expected: 0b11000001, + }, + { + name: "set middle 3 bits", + b: 0b00100001, + mask: 0b00011100, + value: 0b101, + expected: 0b00110101, + }, + { + name: "set lower 3 bits", + b: 0b11111100, + mask: 0b00000111, + value: 0b101, + expected: 0b11111101, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := math.SetBits(tt.b, tt.mask, tt.value) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/memo/arg.go b/pkg/memo/arg.go new file mode 100644 index 0000000000..e29615f562 --- /dev/null +++ b/pkg/memo/arg.go @@ -0,0 +1,52 @@ +package memo + +// ArgType is the enum for types supported by the codec +type ArgType string + +// Define all the types supported by the codec +const ( + ArgTypeBytes ArgType = "bytes" + ArgTypeString ArgType = "string" + ArgTypeAddress ArgType = "address" +) + +// CodecArg represents a codec argument +type CodecArg struct { + Name string + Type ArgType + Arg interface{} +} + +// NewArg create a new codec argument +func NewArg(name string, argType ArgType, arg interface{}) CodecArg { + return CodecArg{ + Name: name, + Type: argType, + Arg: arg, + } +} + +// ArgReceiver wraps the receiver address in a CodecArg +func ArgReceiver(arg interface{}) CodecArg { + return NewArg("receiver", ArgTypeAddress, arg) +} + +// ArgPayload wraps the payload in a CodecArg +func ArgPayload(arg interface{}) CodecArg { + return NewArg("payload", ArgTypeBytes, arg) +} + +// ArgRevertAddress wraps the revert address in a CodecArg +func ArgRevertAddress(arg interface{}) CodecArg { + return NewArg("revertAddress", ArgTypeString, arg) +} + +// ArgAbortAddress wraps the abort address in a CodecArg +func ArgAbortAddress(arg interface{}) CodecArg { + return NewArg("abortAddress", ArgTypeAddress, arg) +} + +// ArgRevertMessage wraps the revert message in a CodecArg +func ArgRevertMessage(arg interface{}) CodecArg { + return NewArg("revertMessage", ArgTypeBytes, arg) +} diff --git a/pkg/memo/arg_test.go b/pkg/memo/arg_test.go new file mode 100644 index 0000000000..5fa9e6a87c --- /dev/null +++ b/pkg/memo/arg_test.go @@ -0,0 +1,70 @@ +package memo_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" +) + +func Test_NewArg(t *testing.T) { + argAddress := sample.EthAddress() + argString := sample.String() + argBytes := sample.Bytes() + + tests := []struct { + name string + argType string + arg interface{} + }{ + { + name: "receiver", + argType: "address", + arg: &argAddress, + }, + { + name: "payload", + argType: "bytes", + arg: &argBytes, + }, + { + name: "revertAddress", + argType: "string", + arg: &argString, + }, + { + name: "abortAddress", + argType: "address", + arg: &argAddress, + }, + { + name: "revertMessage", + argType: "bytes", + arg: &argBytes, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arg := memo.NewArg(tt.name, memo.ArgType(tt.argType), tt.arg) + + require.Equal(t, tt.name, arg.Name) + require.Equal(t, memo.ArgType(tt.argType), arg.Type) + require.Equal(t, tt.arg, arg.Arg) + + switch tt.name { + case "receiver": + require.Equal(t, arg, memo.ArgReceiver(&argAddress)) + case "payload": + require.Equal(t, arg, memo.ArgPayload(&argBytes)) + case "revertAddress": + require.Equal(t, arg, memo.ArgRevertAddress(&argString)) + case "abortAddress": + require.Equal(t, arg, memo.ArgAbortAddress(&argAddress)) + case "revertMessage": + require.Equal(t, arg, memo.ArgRevertMessage(&argBytes)) + } + }) + } +} diff --git a/pkg/memo/codec.go b/pkg/memo/codec.go new file mode 100644 index 0000000000..94539108a3 --- /dev/null +++ b/pkg/memo/codec.go @@ -0,0 +1,64 @@ +package memo + +import ( + "fmt" + + "github.com/pkg/errors" +) + +// Enum for non-EVM chain memo encoding format (2 bits) +const ( + // EncodingFmtABI represents ABI encoding format + EncodingFmtABI uint8 = 0b00 + + // EncodingFmtCompactShort represents 'compact short' encoding format + EncodingFmtCompactShort uint8 = 0b01 + + // EncodingFmtCompactLong represents 'compact long' encoding format + EncodingFmtCompactLong uint8 = 0b10 + + // EncodingFmtMax is the max value of encoding format + EncodingFmtMax uint8 = 0b11 +) + +// Enum for length of bytes used to encode compact data +const ( + LenBytesShort = 1 + LenBytesLong = 2 +) + +// Codec is the interface for a codec +type Codec interface { + // AddArguments adds a list of arguments to the codec + AddArguments(args ...CodecArg) + + // PackArguments packs the arguments into the encoded data + PackArguments() ([]byte, error) + + // UnpackArguments unpacks the encoded data into the arguments + UnpackArguments(data []byte) error +} + +// GetLenBytes returns the number of bytes used to encode the length of the data +func GetLenBytes(encodingFmt uint8) (int, error) { + switch encodingFmt { + case EncodingFmtCompactShort: + return LenBytesShort, nil + case EncodingFmtCompactLong: + return LenBytesLong, nil + default: + return 0, fmt.Errorf("invalid compact encoding format %d", encodingFmt) + } +} + +// GetCodec returns the codec based on the encoding format +func GetCodec(encodingFormat uint8) (Codec, error) { + switch encodingFormat { + case EncodingFmtABI: + return NewCodecABI(), nil + case EncodingFmtCompactShort, EncodingFmtCompactLong: + return NewCodecCompact(encodingFormat) + default: + return nil, errors.New("unsupported encoding format") + } +} diff --git a/pkg/memo/codec_abi.go b/pkg/memo/codec_abi.go new file mode 100644 index 0000000000..9517722006 --- /dev/null +++ b/pkg/memo/codec_abi.go @@ -0,0 +1,92 @@ +package memo + +import ( + "fmt" + "strings" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/pkg/errors" +) + +const ( + // ABIAlignment is the number of bytes used to align the ABI encoded data + ABIAlignment = 32 + + // selectorLength is the length of the selector in bytes + selectorLength = 4 + + // codecMethod is the name of the codec method + codecMethod = "codec" + + // codecMethodABIString is the ABI string template for codec method + codecMethodABIString = `[{"name":"codec", "inputs":[%s], "outputs":[%s], "type":"function"}]` +) + +var _ Codec = (*CodecABI)(nil) + +// CodecABI is a coder/decoder for ABI encoded memo fields +type CodecABI struct { + // abiTypes contains the ABI types of the arguments + abiTypes []string + + // abiArgs contains the ABI arguments to be packed or unpacked into + abiArgs []interface{} +} + +// NewCodecABI creates a new ABI codec +func NewCodecABI() *CodecABI { + return &CodecABI{ + abiTypes: make([]string, 0), + abiArgs: make([]interface{}, 0), + } +} + +// AddArguments adds a list of arguments to the codec +func (c *CodecABI) AddArguments(args ...CodecArg) { + for _, arg := range args { + typeJSON := fmt.Sprintf(`{"type":"%s"}`, arg.Type) + c.abiTypes = append(c.abiTypes, typeJSON) + c.abiArgs = append(c.abiArgs, arg.Arg) + } +} + +// PackArguments packs the arguments into the ABI encoded data +func (c *CodecABI) PackArguments() ([]byte, error) { + // get parsed ABI based on the inputs + parsedABI, err := c.parsedABI() + if err != nil { + return nil, errors.Wrap(err, "failed to parse ABI string") + } + + // pack the arguments + data, err := parsedABI.Pack(codecMethod, c.abiArgs...) + if err != nil { + return nil, errors.Wrap(err, "failed to pack ABI arguments") + } + + return data[selectorLength:], nil +} + +// UnpackArguments unpacks the ABI encoded data into the output arguments +func (c *CodecABI) UnpackArguments(data []byte) error { + // get parsed ABI based on the inputs + parsedABI, err := c.parsedABI() + if err != nil { + return errors.Wrap(err, "failed to parse ABI string") + } + + // unpack data into outputs + err = parsedABI.UnpackIntoInterface(&c.abiArgs, codecMethod, data) + if err != nil { + return errors.Wrap(err, "failed to unpack ABI encoded data") + } + + return nil +} + +// parsedABI builds a parsed ABI based on the inputs +func (c *CodecABI) parsedABI() (abi.ABI, error) { + typeList := strings.Join(c.abiTypes, ",") + abiString := fmt.Sprintf(codecMethodABIString, typeList, typeList) + return abi.JSON(strings.NewReader(abiString)) +} diff --git a/pkg/memo/codec_abi_test.go b/pkg/memo/codec_abi_test.go new file mode 100644 index 0000000000..f61f9ac774 --- /dev/null +++ b/pkg/memo/codec_abi_test.go @@ -0,0 +1,227 @@ +package memo_test + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" +) + +// newArgInstance creates a new instance of the given argument type +func newArgInstance(v interface{}) interface{} { + switch v.(type) { + case common.Address: + return new(common.Address) + case []byte: + return &[]byte{} + case string: + return new(string) + } + return nil +} + +// ensureArgEquality ensures the expected argument and actual value are equal +func ensureArgEquality(t *testing.T, expected, actual interface{}) { + switch v := expected.(type) { + case common.Address: + require.Equal(t, v.Hex(), actual.(*common.Address).Hex()) + case []byte: + require.True(t, bytes.Equal(v, *actual.(*[]byte))) + case string: + require.Equal(t, v, *actual.(*string)) + default: + t.Fatalf("unexpected argument type: %T", v) + } +} + +func Test_NewCodecABI(t *testing.T) { + c := memo.NewCodecABI() + require.NotNil(t, c) +} + +func Test_CodecABI_AddArguments(t *testing.T) { + codec := memo.NewCodecABI() + require.NotNil(t, codec) + + address := sample.EthAddress() + codec.AddArguments(memo.ArgReceiver(&address)) +} + +func Test_CodecABI_PackArgument(t *testing.T) { + // create sample arguments + argAddress := sample.EthAddress() + argBytes := []byte("some test bytes argument") + argString := "some test string argument" + + // test cases + tests := []struct { + name string + args []memo.CodecArg + errMsg string + }{ + { + name: "pack in the order of [address, bytes, string]", + args: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + }, + { + name: "pack in the order of [string, address, bytes]", + args: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + }, + { + name: "pack empty bytes array and string", + args: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + }, + { + name: "unable to parse unsupported ABI type", + args: []memo.CodecArg{ + memo.ArgReceiver(&argAddress), + memo.NewArg("payload", memo.ArgType("unknown"), nil), + }, + errMsg: "failed to parse ABI string", + }, + { + name: "packing should fail on argument type mismatch", + args: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "failed to pack ABI arguments", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new ABI codec and add arguments + codec := memo.NewCodecABI() + codec.AddArguments(tc.args...) + + // pack arguments into ABI-encoded packedData + packedData, err := codec.PackArguments() + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + require.Nil(t, packedData) + return + } + require.NoError(t, err) + + // calc expected data for comparison + expectedData := sample.ABIPack(t, tc.args...) + + // validate the packed data + require.True(t, bytes.Equal(expectedData, packedData), "ABI encoded data mismatch") + }) + } +} + +func Test_CodecABI_UnpackArguments(t *testing.T) { + // create sample arguments + argAddress := sample.EthAddress() + argBytes := []byte("some test bytes argument") + argString := "some test string argument" + + // test cases + tests := []struct { + name string + data []byte + expected []memo.CodecArg + errMsg string + }{ + { + name: "unpack in the order of [address, bytes, string]", + data: sample.ABIPack(t, + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString)), + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + }, + { + name: "unpack in the order of [string, address, bytes]", + data: sample.ABIPack(t, + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes)), + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + }, + { + name: "unpack empty bytes array and string", + data: sample.ABIPack(t, + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress("")), + expected: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + }, + { + name: "unable to parse unsupported ABI type", + data: []byte{}, + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.NewArg("payload", memo.ArgType("unknown"), nil), + }, + errMsg: "failed to parse ABI string", + }, + { + name: "unpacking should fail on argument type mismatch", + data: sample.ABIPack(t, + memo.ArgReceiver(argAddress), + ), + expected: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "failed to unpack ABI encoded data", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new ABI codec + codec := memo.NewCodecABI() + + // add output arguments + output := make([]memo.CodecArg, len(tc.expected)) + for i, arg := range tc.expected { + output[i] = memo.NewArg(arg.Name, arg.Type, newArgInstance(arg.Arg)) + } + codec.AddArguments(output...) + + // unpack arguments from ABI-encoded data + err := codec.UnpackArguments(tc.data) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // validate the unpacked arguments values + require.NoError(t, err) + for i, arg := range tc.expected { + ensureArgEquality(t, arg.Arg, output[i].Arg) + } + }) + } +} diff --git a/pkg/memo/codec_compact.go b/pkg/memo/codec_compact.go new file mode 100644 index 0000000000..def5213a01 --- /dev/null +++ b/pkg/memo/codec_compact.go @@ -0,0 +1,259 @@ +package memo + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" +) + +var _ Codec = (*CodecCompact)(nil) + +// CodecCompact is a coder/decoder for compact encoded memo fields +type CodecCompact struct { + // lenBytes is the number of bytes used to encode the length of the data + lenBytes int + + // args contains the list of arguments + args []CodecArg +} + +// NewCodecCompact creates a new compact codec +func NewCodecCompact(encodingFmt uint8) (*CodecCompact, error) { + lenBytes, err := GetLenBytes(encodingFmt) + if err != nil { + return nil, err + } + + return &CodecCompact{ + lenBytes: lenBytes, + args: make([]CodecArg, 0), + }, nil +} + +// AddArguments adds a list of arguments to the codec +func (c *CodecCompact) AddArguments(args ...CodecArg) { + c.args = append(c.args, args...) +} + +// PackArguments packs the arguments into the compact encoded data +func (c *CodecCompact) PackArguments() ([]byte, error) { + data := make([]byte, 0) + + // pack according to argument type + for _, arg := range c.args { + switch arg.Type { + case ArgTypeBytes: + dataBytes, err := c.packBytes(arg.Arg) + if err != nil { + return nil, errors.Wrapf(err, "failed to pack bytes argument: %s", arg.Name) + } + data = append(data, dataBytes...) + case ArgTypeAddress: + dateAddress, err := c.packAddress(arg.Arg) + if err != nil { + return nil, errors.Wrapf(err, "failed to pack address argument: %s", arg.Name) + } + data = append(data, dateAddress...) + case ArgTypeString: + dataString, err := c.packString(arg.Arg) + if err != nil { + return nil, errors.Wrapf(err, "failed to pack string argument: %s", arg.Name) + } + data = append(data, dataString...) + default: + return nil, fmt.Errorf("unsupported argument (%s) type: %s", arg.Name, arg.Type) + } + } + + return data, nil +} + +// UnpackArguments unpacks the compact encoded data into the output arguments +func (c *CodecCompact) UnpackArguments(data []byte) error { + // unpack according to argument type + offset := 0 + for _, arg := range c.args { + switch arg.Type { + case ArgTypeBytes: + bytesRead, err := c.unpackBytes(data[offset:], arg.Arg) + if err != nil { + return errors.Wrapf(err, "failed to unpack bytes argument: %s", arg.Name) + } + offset += bytesRead + case ArgTypeAddress: + bytesRead, err := c.unpackAddress(data[offset:], arg.Arg) + if err != nil { + return errors.Wrapf(err, "failed to unpack address argument: %s", arg.Name) + } + offset += bytesRead + case ArgTypeString: + bytesRead, err := c.unpackString(data[offset:], arg.Arg) + if err != nil { + return errors.Wrapf(err, "failed to unpack string argument: %s", arg.Name) + } + offset += bytesRead + default: + return fmt.Errorf("unsupported argument (%s) type: %s", arg.Name, arg.Type) + } + } + + // ensure all data is consumed + if offset != len(data) { + return fmt.Errorf("consumed bytes (%d) != total bytes (%d)", offset, len(data)) + } + + return nil +} + +// packLength packs the length of the data into the compact format +func (c *CodecCompact) packLength(length int) ([]byte, error) { + data := make([]byte, c.lenBytes) + + switch c.lenBytes { + case LenBytesShort: + if length > math.MaxUint8 { + return nil, fmt.Errorf("data length %d exceeds %d bytes", length, math.MaxUint8) + } + data[0] = uint8(length) + case LenBytesLong: + if length > math.MaxUint16 { + return nil, fmt.Errorf("data length %d exceeds %d bytes", length, math.MaxUint16) + } + binary.LittleEndian.PutUint16(data, uint16(length)) + } + return data, nil +} + +// packAddress packs argument of type 'address'. +func (c *CodecCompact) packAddress(arg interface{}) ([]byte, error) { + // type assertion + address, ok := arg.(common.Address) + if !ok { + return nil, fmt.Errorf("argument is not of type common.Address") + } + + return address.Bytes(), nil +} + +// packBytes packs argument of type 'bytes'. +func (c *CodecCompact) packBytes(arg interface{}) ([]byte, error) { + // type assertion + bytes, ok := arg.([]byte) + if !ok { + return nil, fmt.Errorf("argument is not of type []byte") + } + + // pack length of the data + data, err := c.packLength(len(bytes)) + if err != nil { + return nil, errors.Wrap(err, "failed to pack length of bytes") + } + + // append the data + data = append(data, bytes...) + return data, nil +} + +// packString packs argument of type 'string'. +func (c *CodecCompact) packString(arg interface{}) ([]byte, error) { + // type assertion + str, ok := arg.(string) + if !ok { + return nil, fmt.Errorf("argument is not of type string") + } + + // pack length of the data + data, err := c.packLength(len([]byte(str))) + if err != nil { + return nil, errors.Wrap(err, "failed to pack length of string") + } + + // append the string + data = append(data, []byte(str)...) + return data, nil +} + +// unpackLength returns the length of the data encoded in the compact format +func (c *CodecCompact) unpackLength(data []byte) (int, error) { + if len(data) < c.lenBytes { + return 0, fmt.Errorf("expected %d bytes to decode length, got %d", c.lenBytes, len(data)) + } + + // decode length of the data + length := 0 + switch c.lenBytes { + case LenBytesShort: + length = int(data[0]) + case LenBytesLong: + // convert little-endian bytes to integer + length = int(binary.LittleEndian.Uint16(data[:2])) + } + + // ensure remaining data is long enough + if len(data) < c.lenBytes+length { + return 0, fmt.Errorf("expected %d bytes, got %d", length, len(data)-c.lenBytes) + } + + return length, nil +} + +// unpackAddress unpacks argument of type 'address'. +func (c *CodecCompact) unpackAddress(data []byte, output interface{}) (int, error) { + // type assertion + pAddress, ok := output.(*common.Address) + if !ok { + return 0, fmt.Errorf("argument is not of type *common.Address") + } + + // ensure remaining data >= 20 bytes + if len(data) < common.AddressLength { + return 0, fmt.Errorf("expected address, got %d bytes", len(data)) + } + *pAddress = common.BytesToAddress((data[:20])) + + return common.AddressLength, nil +} + +// unpackBytes unpacks argument of type 'bytes' and returns the number of bytes read. +func (c *CodecCompact) unpackBytes(data []byte, output interface{}) (int, error) { + // type assertion + pSlice, ok := output.(*[]byte) + if !ok { + return 0, fmt.Errorf("argument is not of type *[]byte") + } + + // unpack length + dataLen, err := c.unpackLength(data) + if err != nil { + return 0, errors.Wrap(err, "failed to unpack length of bytes") + } + + // make a copy of the data + *pSlice = make([]byte, dataLen) + copy(*pSlice, data[c.lenBytes:c.lenBytes+dataLen]) + + return c.lenBytes + dataLen, nil +} + +// unpackString unpacks argument of type 'string' and returns the number of bytes read. +func (c *CodecCompact) unpackString(data []byte, output interface{}) (int, error) { + // type assertion + pString, ok := output.(*string) + if !ok { + return 0, fmt.Errorf("argument is not of type *string") + } + + // unpack length + strLen, err := c.unpackLength(data) + if err != nil { + return 0, errors.Wrap(err, "failed to unpack length of string") + } + + // make a copy of the string + *pString = string(data[c.lenBytes : c.lenBytes+strLen]) + + return c.lenBytes + strLen, nil +} diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go new file mode 100644 index 0000000000..4df32fe915 --- /dev/null +++ b/pkg/memo/codec_compact_test.go @@ -0,0 +1,350 @@ +package memo_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" +) + +func Test_NewCodecCompact(t *testing.T) { + tests := []struct { + name string + encodingFmt uint8 + fail bool + }{ + { + name: "create codec compact successfully", + encodingFmt: memo.EncodingFmtCompactShort, + }, + { + name: "create codec compact failed on invalid encoding format", + encodingFmt: 0b11, + fail: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + codec, err := memo.NewCodecCompact(tc.encodingFmt) + if tc.fail { + require.Error(t, err) + require.Nil(t, codec) + } else { + require.NoError(t, err) + require.NotNil(t, codec) + } + }) + } +} + +func Test_CodecCompact_AddArguments(t *testing.T) { + codec := memo.NewCodecABI() + require.NotNil(t, codec) + + address := sample.EthAddress() + codec.AddArguments(memo.ArgReceiver(&address)) +} + +func Test_CodecCompact_PackArguments(t *testing.T) { + // create sample arguments + argAddress := sample.EthAddress() + argBytes := []byte("here is a bytes argument") + argString := "some other string argument" + + // test cases + tests := []struct { + name string + encodingFmt uint8 + args []memo.CodecArg + expectedLen int + errMsg string + }{ + { + name: "pack arguments of [address, bytes, string] in compact-short format", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + expectedLen: 20 + 1 + len(argBytes) + 1 + len([]byte(argString)), + }, + { + name: "pack arguments of [string, address, bytes] in compact-long format", + encodingFmt: memo.EncodingFmtCompactLong, + args: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + expectedLen: 2 + len([]byte(argString)) + 20 + 2 + len(argBytes), + }, + { + name: "pack long string (> 255 bytes) with compact-long format", + encodingFmt: memo.EncodingFmtCompactLong, + args: []memo.CodecArg{ + memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 256))), + }, + expectedLen: 2 + 256, + }, + { + name: "pack long string (> 255 bytes) with compact-short format should fail", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 256))), + }, + errMsg: "exceeds 255 bytes", + }, + { + name: "pack long string (> 65535 bytes) with compact-long format should fail", + encodingFmt: memo.EncodingFmtCompactLong, + args: []memo.CodecArg{ + memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 65536))), + }, + errMsg: "exceeds 65535 bytes", + }, + { + name: "pack empty byte array and string arguments", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + expectedLen: 2, + }, + { + name: "failed to pack bytes argument if string is passed", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgPayload(argString), // expect bytes type, but passed string + }, + errMsg: "argument is not of type []byte", + }, + { + name: "failed to pack address argument if bytes is passed", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "argument is not of type common.Address", + }, + { + name: "failed to pack string argument if bytes is passed", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgRevertAddress(argBytes), // expect string type, but passed bytes + }, + errMsg: "argument is not of type string", + }, + { + name: "failed to pack unsupported argument type", + encodingFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.NewArg("receiver", memo.ArgType("unknown"), nil), + }, + errMsg: "unsupported argument (receiver) type", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new compact codec and add arguments + codec, err := memo.NewCodecCompact(tc.encodingFmt) + require.NoError(t, err) + codec.AddArguments(tc.args...) + + // pack arguments + packedData, err := codec.PackArguments() + + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + require.Nil(t, packedData) + return + } + require.NoError(t, err) + require.Equal(t, tc.expectedLen, len(packedData)) + + // calc expected data for comparison + expectedData := sample.CompactPack(t, tc.encodingFmt, tc.args...) + + // validate the packed data + require.True(t, bytes.Equal(expectedData, packedData), "compact encoded data mismatch") + }) + } +} + +func Test_CodecCompact_UnpackArguments(t *testing.T) { + // create sample arguments + argAddress := sample.EthAddress() + argBytes := []byte("some test bytes argument") + argString := "some other string argument" + + // test cases + tests := []struct { + name string + encodingFmt uint8 + data []byte + expected []memo.CodecArg + errMsg string + }{ + { + name: "unpack arguments of [address, bytes, string] in compact-short format", + encodingFmt: memo.EncodingFmtCompactShort, + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + ), + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + }, + { + name: "unpack arguments of [string, address, bytes] in compact-long format", + encodingFmt: memo.EncodingFmtCompactLong, + data: sample.CompactPack(t, + memo.EncodingFmtCompactLong, + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + ), + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + }, + { + name: "unpack empty byte array and string argument", + encodingFmt: memo.EncodingFmtCompactShort, + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + ), + expected: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + }, + { + name: "failed to unpack string if data length < 1 byte", + encodingFmt: memo.EncodingFmtCompactShort, + data: []byte{}, + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + }, + errMsg: "expected 1 bytes to decode length", + }, + { + name: "failed to unpack string if actual data is less than decoded length", + encodingFmt: memo.EncodingFmtCompactShort, + data: []byte{0x05, 0x0a, 0x0b, 0x0c, 0x0d}, // length = 5, but only 4 bytes provided + expected: []memo.CodecArg{ + memo.ArgPayload(argBytes), + }, + errMsg: "expected 5 bytes, got 4", + }, + { + name: "failed to unpack bytes argument if string is passed", + encodingFmt: memo.EncodingFmtCompactShort, + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgPayload(argBytes), + ), + expected: []memo.CodecArg{ + memo.ArgPayload(argString), // expect bytes type, but passed string + }, + errMsg: "argument is not of type *[]byte", + }, + { + name: "failed to unpack address argument if bytes is passed", + encodingFmt: memo.EncodingFmtCompactShort, + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgReceiver(argAddress), + ), + expected: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "argument is not of type *common.Address", + }, + { + name: "failed to unpack string argument if address is passed", + encodingFmt: memo.EncodingFmtCompactShort, + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgRevertAddress(argString), + ), + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argAddress), // expect string type, but passed address + }, + errMsg: "argument is not of type *string", + }, + { + name: "failed to unpack unsupported argument type", + encodingFmt: memo.EncodingFmtCompactShort, + data: []byte{}, + expected: []memo.CodecArg{ + memo.NewArg("payload", memo.ArgType("unknown"), nil), + }, + errMsg: "unsupported argument (payload) type", + }, + { + name: "unpacking should fail if not all data is consumed", + encodingFmt: memo.EncodingFmtCompactShort, + data: func() []byte { + data := sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + ) + // append 1 extra byte + return append(data, 0x00) + }(), + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + errMsg: "consumed bytes (45) != total bytes (46)", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new compact codec and add arguments + codec, err := memo.NewCodecCompact(tc.encodingFmt) + require.NoError(t, err) + + // add output arguments + output := make([]memo.CodecArg, len(tc.expected)) + for i, arg := range tc.expected { + output[i] = memo.NewArg(arg.Name, arg.Type, newArgInstance(arg.Arg)) + } + codec.AddArguments(output...) + + // unpack arguments from compact-encoded data + err = codec.UnpackArguments(tc.data) + + // validate error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // validate the unpacked arguments values + require.NoError(t, err) + for i, arg := range tc.expected { + ensureArgEquality(t, arg.Arg, output[i].Arg) + } + }) + } +} diff --git a/pkg/memo/codec_test.go b/pkg/memo/codec_test.go new file mode 100644 index 0000000000..27eb73d977 --- /dev/null +++ b/pkg/memo/codec_test.go @@ -0,0 +1,92 @@ +package memo_test + +import ( + "testing" + + "github.com/test-go/testify/require" + "github.com/zeta-chain/node/pkg/memo" +) + +func Test_GetLenBytes(t *testing.T) { + // Define table-driven test cases + tests := []struct { + name string + encodingFmt uint8 + expectedLen int + expectErr bool + }{ + { + name: "compact short", + encodingFmt: memo.EncodingFmtCompactShort, + expectedLen: 1, + }, + { + name: "compact long", + encodingFmt: memo.EncodingFmtCompactLong, + expectedLen: 2, + }, + { + name: "non-compact encoding format", + encodingFmt: memo.EncodingFmtABI, + expectedLen: 0, + expectErr: true, + }, + } + + // Loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + length, err := memo.GetLenBytes(tc.encodingFmt) + + // Check if error is expected + if tc.expectErr { + require.Error(t, err) + require.Equal(t, 0, length) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedLen, length) + } + }) + } +} + +func Test_GetCodec(t *testing.T) { + // Define table-driven test cases + tests := []struct { + name string + encodingFmt uint8 + errMsg string + }{ + { + name: "should get ABI codec", + encodingFmt: memo.EncodingFmtABI, + }, + { + name: "should get compact codec", + encodingFmt: memo.EncodingFmtCompactShort, + }, + { + name: "should get compact codec", + encodingFmt: memo.EncodingFmtCompactLong, + }, + { + name: "should fail to get codec", + encodingFmt: 0b11, + errMsg: "unsupported", + }, + } + + // Loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + codec, err := memo.GetCodec(tc.encodingFmt) + if tc.errMsg != "" { + require.Error(t, err) + require.Nil(t, codec) + } else { + require.NoError(t, err) + require.NotNil(t, codec) + } + }) + } +} diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go new file mode 100644 index 0000000000..2561213351 --- /dev/null +++ b/pkg/memo/fields_v0.go @@ -0,0 +1,133 @@ +package memo + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" + + zetamath "github.com/zeta-chain/node/pkg/math" + crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" +) + +// Enum of the bit position of each memo fields +const ( + bitPosPayload uint8 = 0 // payload + bitPosRevertAddress uint8 = 1 // revertAddress + bitPosAbortAddress uint8 = 2 // abortAddress + bitPosRevertMessage uint8 = 3 // revertMessage +) + +// FieldsV0 contains the data fields of the inbound memo V0 +type FieldsV0 struct { + // Receiver is the ZEVM receiver address + Receiver common.Address + + // Payload is the calldata passed to ZEVM contract call + Payload []byte + + // RevertOptions is the options for cctx revert handling + RevertOptions *crosschaintypes.RevertOptions +} + +// FieldsEncoderV0 is the encoder for outbound memo fields V0 +func FieldsEncoderV0(memo *InboundMemo) ([]byte, error) { + codec, err := GetCodec(memo.EncodingFormat) + if err != nil { + return nil, errors.Wrap(err, "unable to get codec") + } + + return PackMemoFieldsV0(codec, memo) +} + +// FieldsDecoderV0 is the decoder for inbound memo fields V0 +func FieldsDecoderV0(data []byte, memo *InboundMemo) error { + codec, err := GetCodec(memo.EncodingFormat) + if err != nil { + return errors.Wrap(err, "unable to get codec") + } + return UnpackMemoFieldsV0(codec, data[HeaderSize:], memo) +} + +// PackMemoFieldsV0 packs the memo fields for version 0 +func PackMemoFieldsV0(codec Codec, memo *InboundMemo) ([]byte, error) { + // create data flags byte + dataFlags := byte(0) + + // add 'receiver' as the first argument + codec.AddArguments(ArgReceiver(memo.Receiver)) + + // add 'payload' argument optionally + if len(memo.Payload) > 0 { + zetamath.SetBit(&dataFlags, bitPosPayload) + codec.AddArguments(ArgPayload(memo.Payload)) + } + + if memo.RevertOptions != nil { + // add 'revertAddress' argument optionally + if memo.RevertOptions.RevertAddress != "" { + zetamath.SetBit(&dataFlags, bitPosRevertAddress) + codec.AddArguments(ArgRevertAddress(memo.RevertOptions.RevertAddress)) + } + + // add 'abortAddress' argument optionally + if memo.RevertOptions.AbortAddress != "" { + zetamath.SetBit(&dataFlags, bitPosAbortAddress) + codec.AddArguments(ArgAbortAddress(common.HexToAddress(memo.RevertOptions.AbortAddress))) + } + + // add 'revertMessage' argument optionally + if memo.RevertOptions.CallOnRevert { + zetamath.SetBit(&dataFlags, bitPosRevertMessage) + codec.AddArguments(ArgRevertMessage(memo.RevertOptions.RevertMessage)) + } + } + + // pack the codec arguments into data + data, err := codec.PackArguments() + if err != nil { + return nil, err + } + + return append([]byte{dataFlags}, data...), nil +} + +// UnpackMemoFieldsV0 unpacks the memo fields for version 0 +func UnpackMemoFieldsV0(codec Codec, data []byte, memo *InboundMemo) error { + // byte-2 contains data flags + dataFlags := data[2] + + // add 'receiver' as the first argument + codec.AddArguments(ArgReceiver(&memo.Receiver)) + + // add 'payload' argument optionally + if zetamath.IsBitSet(dataFlags, bitPosPayload) { + codec.AddArguments(ArgPayload(&memo.Payload)) + } + + // add 'revertAddress' argument optionally + if zetamath.IsBitSet(dataFlags, bitPosRevertAddress) { + codec.AddArguments(ArgRevertAddress(&memo.RevertOptions.RevertAddress)) + } + + // add 'abortAddress' argument optionally + var abortAddress common.Address + if zetamath.IsBitSet(dataFlags, bitPosRevertMessage) { + codec.AddArguments(ArgAbortAddress(&abortAddress)) + } + + // add 'revertMessage' argument optionally + memo.RevertOptions.CallOnRevert = zetamath.IsBitSet(dataFlags, bitPosAbortAddress) + if memo.RevertOptions.CallOnRevert { + codec.AddArguments(ArgRevertMessage(&memo.RevertOptions.RevertMessage)) + } + + // unpack the data (after header) into codec arguments + err := codec.UnpackArguments(data[HeaderSize:]) + if err != nil { + return err + } + + // convert abort address to string + memo.RevertOptions.AbortAddress = abortAddress.Hex() + + return nil +} diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go new file mode 100644 index 0000000000..bbe785e3ce --- /dev/null +++ b/pkg/memo/memo.go @@ -0,0 +1,169 @@ +package memo + +import ( + "fmt" + + "github.com/pkg/errors" + + zetamath "github.com/zeta-chain/node/pkg/math" +) + +const ( + // MemoIdentifier is the ASCII code of 'Z' (0x5A) + MemoIdentifier byte = 0x5A + + // HeaderSize is the size of the memo header + HeaderSize = 3 + + // MaskVersion is the mask for the version bits(5~7) + BitMaskVersion byte = 0b11100000 + + // BitMaskEncodingFormat is the mask for the encoding format bits(3~4) + BitMaskEncodingFormat byte = 0b00011000 + + // BitMaskOpCode is the mask for the operation code bits(0~2) + BitMaskOpCode byte = 0b00000111 +) + +// Enum for non-EVM chain inbound operation code (3 bits) +const ( + InboundOpCodeDeposit uint8 = 0b000 // operation 'deposit' + InboundOpCodeDepositAndCall uint8 = 0b001 // operation 'deposit_and_call' + InboundOpCodeCall uint8 = 0b010 // operation 'call' + InboundOpCodeMax uint8 = 0b011 // operation max value +) + +// Encoder is the interface for outbound memo encoders +type Encoder func(memo *InboundMemo) ([]byte, error) + +// Decoder is the interface for inbound memo decoders +type Decoder func(data []byte, memo *InboundMemo) error + +// memoEncoderRegistry contains all registered memo encoders +var memoEncoderRegistry = map[uint8]Encoder{ + 0: FieldsEncoderV0, +} + +// memoDecoderRegistry contains all registered memo decoders +var memoDecoderRegistry = map[uint8]Decoder{ + 0: FieldsDecoderV0, +} + +// InboundMemo represents the memo structure for non-EVM chains +type InboundMemo struct { + // Version is the memo Version + Version uint8 + + // EncodingFormat is the memo encoding format + EncodingFormat uint8 + + // OpCode is the inbound operation code + OpCode uint8 + + // FieldsV0 contains the memo fields V0 + // Note: add a MemoFieldsV1 if major change is needed in the future + FieldsV0 +} + +// EncodeMemoToBytes encodes a InboundMemo struct to raw bytes +func EncodeMemoToBytes(memo *InboundMemo) ([]byte, error) { + // get registered memo encoder by version + encoder, found := memoEncoderRegistry[memo.Version] + if !found { + return nil, fmt.Errorf("encoder not found for memo version: %d", memo.Version) + } + + // encode memo basics + basics := EncodeMemoBasics(memo) + + // encode memo fields using the encoder + data, err := encoder(memo) + if err != nil { + return nil, errors.Wrap(err, "failed to encode memo fields") + } + + return append(basics, data...), nil +} + +// EncodeMemoBasics encodes the version, encoding format and operation code +func EncodeMemoBasics(memo *InboundMemo) []byte { + // create 3-byte header + head := make([]byte, HeaderSize) + + // set byte-0 as memo identifier + head[0] = MemoIdentifier + + // set version # and encoding format + var ctrlByte byte + ctrlByte = zetamath.SetBits(ctrlByte, BitMaskVersion, memo.Version) + ctrlByte = zetamath.SetBits(ctrlByte, BitMaskEncodingFormat, memo.EncodingFormat) + ctrlByte = zetamath.SetBits(ctrlByte, BitMaskOpCode, memo.OpCode) + + // set ctrlByte to byte-1 + head[1] = ctrlByte + + return head +} + +// DecodeMemoFromBytes decodes a InboundMemo struct from raw bytes +// +// Returns an error if given data is not a valid memo +func DecodeMemoFromBytes(data []byte) (*InboundMemo, error) { + memo := &InboundMemo{} + + // decode memo basics + err := DecodeMemoBasics(data, memo) + if err != nil { + return nil, err + } + + // get registered memo decoder by version + decoder, found := memoDecoderRegistry[memo.Version] + if !found { + return nil, fmt.Errorf("decoder not found for memo version: %d", memo.Version) + } + + // decode memo fields using the decoer + err = decoder(data, memo) + if err != nil { + return nil, errors.Wrap(err, "failed to decode memo fields") + } + + return memo, nil +} + +// DecodeMemoBasics decodes version, encoding format and operation code +func DecodeMemoBasics(data []byte, memo *InboundMemo) error { + // memo data must be longer than the header size + if len(data) <= HeaderSize { + return errors.New("memo data too short") + } + + // byte-0 is the memo identifier + if data[0] != MemoIdentifier { + return errors.New("memo identifier mismatch") + } + + // byte-1 is the control byte + ctrlByte := data[1] + + // extract version # + memo.Version = zetamath.GetBits(ctrlByte, BitMaskVersion) + if memo.Version != 0 { + return fmt.Errorf("unsupported memo version: %d", memo.Version) + } + + // extract encoding format + memo.EncodingFormat = zetamath.GetBits(ctrlByte, BitMaskEncodingFormat) + if memo.EncodingFormat >= EncodingFmtMax { + return fmt.Errorf("invalid encoding format: %d", memo.EncodingFormat) + } + + // extract operation code + memo.OpCode = zetamath.GetBits(ctrlByte, BitMaskOpCode) + if memo.OpCode >= InboundOpCodeMax { + return fmt.Errorf("invalid operation code: %d", memo.OpCode) + } + + return nil +} diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go new file mode 100644 index 0000000000..add6944fdf --- /dev/null +++ b/pkg/memo/memo_test.go @@ -0,0 +1,10 @@ +package memo_test + +import ( + "encoding/hex" + "testing" +) + +func Test_NewInboundMemo(t *testing.T) { + hex.EncodeToString([]byte("hello")) +} diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go new file mode 100644 index 0000000000..17f080695f --- /dev/null +++ b/testutil/sample/memo.go @@ -0,0 +1,159 @@ +package sample + +import ( + "encoding/binary" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + + "github.com/zeta-chain/node/pkg/memo" +) + +// ABIPack is a helper function to simulates the abi.Pack function. +// Note: all arguments are assumed to be <= 32 bytes for simplicity. +func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { + packedData := make([]byte, 0) + + // data offset for 1st dynamic-length field + offset := memo.ABIAlignment * len(args) + + // 1. pack 32-byte offset for each dynamic-length field (bytes, string) + // 2. pack actual data for each fixed-length field (address) + for _, arg := range args { + switch arg.Type { + case memo.ArgTypeBytes: // left-pad for uint8 + offsetData := abiPad32(t, []byte{byte(offset)}, true) + packedData = append(packedData, offsetData...) + + argLen := len(arg.Arg.([]byte)) + if argLen > 0 { + offset += memo.ABIAlignment * 2 // [length + data] + } else { + offset += memo.ABIAlignment // only [length] + } + + case memo.ArgTypeString: // left-pad for uint8 + offsetData := abiPad32(t, []byte{byte(offset)}, true) + packedData = append(packedData, offsetData...) + + argLen := len([]byte(arg.Arg.(string))) + if argLen > 0 { + offset += memo.ABIAlignment * 2 // [length + data] + } else { + offset += memo.ABIAlignment // only [length] + } + + case memo.ArgTypeAddress: // left-pad for address + data := abiPad32(t, arg.Arg.(common.Address).Bytes(), true) + packedData = append(packedData, data...) + } + } + + // pack dynamic-length fields + dynamicData := abiPackDynamicData(t, args...) + packedData = append(packedData, dynamicData...) + + return packedData +} + +// CompactPack is a helper function to pack arguments into compact encoded data +// Note: all arguments are assumed to be <= 65535 bytes for simplicity. +func CompactPack(_ *testing.T, encodingFmt uint8, args ...memo.CodecArg) []byte { + var ( + length int + packedData []byte + ) + + for _, arg := range args { + // get length of argument + switch arg.Type { + case memo.ArgTypeBytes: + length = len(arg.Arg.([]byte)) + case memo.ArgTypeString: + length = len([]byte(arg.Arg.(string))) + default: + // skip length for other types + length = -1 + } + + // append length in bytes + if length != -1 { + switch encodingFmt { + case memo.EncodingFmtCompactShort: + packedData = append(packedData, byte(length)) + case memo.EncodingFmtCompactLong: + buff := make([]byte, 2) + binary.LittleEndian.PutUint16(buff, uint16(length)) + packedData = append(packedData, buff...) + } + } + + // append actual data in bytes + switch arg.Type { + case memo.ArgTypeBytes: + packedData = append(packedData, arg.Arg.([]byte)...) + case memo.ArgTypeAddress: + packedData = append(packedData, arg.Arg.(common.Address).Bytes()...) + case memo.ArgTypeString: + packedData = append(packedData, []byte(arg.Arg.(string))...) + } + } + + return packedData +} + +// abiPad32 is a helper function to pad a byte slice to 32 bytes +func abiPad32(t *testing.T, data []byte, left bool) []byte { + // nothing needs to be encoded, return empty bytes + if len(data) == 0 { + return []byte{} + } + + require.LessOrEqual(t, len(data), memo.ABIAlignment) + padded := make([]byte, 32) + + if left { + // left-pad the data for fixed-size types + copy(padded[32-len(data):], data) + } else { + // right-pad the data for dynamic types + copy(padded, data) + } + return padded +} + +// apiPackDynamicData is a helper function to pack dynamic-length data +func abiPackDynamicData(t *testing.T, args ...memo.CodecArg) []byte { + packedData := make([]byte, 0) + + // pack with ABI format: length + data + for _, arg := range args { + // get length + var length int + switch arg.Type { + case memo.ArgTypeBytes: + length = len(arg.Arg.([]byte)) + case memo.ArgTypeString: + length = len([]byte(arg.Arg.(string))) + default: + continue + } + + // append length in bytes + lengthData := abiPad32(t, []byte{byte(length)}, true) + packedData = append(packedData, lengthData...) + + // append actual data in bytes + switch arg.Type { + case memo.ArgTypeBytes: // right-pad for bytes + data := abiPad32(t, arg.Arg.([]byte), false) + packedData = append(packedData, data...) + case memo.ArgTypeString: // right-pad for string + data := abiPad32(t, []byte(arg.Arg.(string)), false) + packedData = append(packedData, data...) + } + } + + return packedData +} From 95acf7c02843be14b8091507d8fc2cfa3b97f90b Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Wed, 9 Oct 2024 01:55:50 -0500 Subject: [PATCH 02/19] add unit tests against memo fields version 0 --- pkg/memo/fields.go | 10 ++ pkg/memo/fields_v0.go | 101 +++++++++-------- pkg/memo/fields_v0_test.go | 221 +++++++++++++++++++++++++++++++++++++ pkg/memo/memo.go | 110 ++++++++---------- testutil/sample/memo.go | 14 ++- 5 files changed, 341 insertions(+), 115 deletions(-) create mode 100644 pkg/memo/fields.go create mode 100644 pkg/memo/fields_v0_test.go diff --git a/pkg/memo/fields.go b/pkg/memo/fields.go new file mode 100644 index 0000000000..99118a1da9 --- /dev/null +++ b/pkg/memo/fields.go @@ -0,0 +1,10 @@ +package memo + +// Fields is the interface for memo fields +type Fields interface { + // Pack encodes the memo fields + Pack(encodingFormat uint8) ([]byte, error) + + // Unpack decodes the memo fields + Unpack(data []byte, encodingFormat uint8) error +} diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index 2561213351..c3082417fa 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -4,6 +4,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "github.com/zeta-chain/node/pkg/crypto" zetamath "github.com/zeta-chain/node/pkg/math" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" ) @@ -16,6 +17,8 @@ const ( bitPosRevertMessage uint8 = 3 // revertMessage ) +var _ Fields = (*FieldsV0)(nil) + // FieldsV0 contains the data fields of the inbound memo V0 type FieldsV0 struct { // Receiver is the ZEVM receiver address @@ -25,87 +28,87 @@ type FieldsV0 struct { Payload []byte // RevertOptions is the options for cctx revert handling - RevertOptions *crosschaintypes.RevertOptions + RevertOptions crosschaintypes.RevertOptions } -// FieldsEncoderV0 is the encoder for outbound memo fields V0 -func FieldsEncoderV0(memo *InboundMemo) ([]byte, error) { - codec, err := GetCodec(memo.EncodingFormat) +// Pack encodes the memo fields +func (f *FieldsV0) Pack(encodingFormat uint8) ([]byte, error) { + codec, err := GetCodec(encodingFormat) if err != nil { return nil, errors.Wrap(err, "unable to get codec") } - return PackMemoFieldsV0(codec, memo) + return f.packFields(codec) } -// FieldsDecoderV0 is the decoder for inbound memo fields V0 -func FieldsDecoderV0(data []byte, memo *InboundMemo) error { - codec, err := GetCodec(memo.EncodingFormat) +// Unpack decodes the memo fields +func (f *FieldsV0) Unpack(data []byte, encodingFormat uint8) error { + codec, err := GetCodec(encodingFormat) if err != nil { return errors.Wrap(err, "unable to get codec") } - return UnpackMemoFieldsV0(codec, data[HeaderSize:], memo) + + return f.unpackFields(codec, data) } -// PackMemoFieldsV0 packs the memo fields for version 0 -func PackMemoFieldsV0(codec Codec, memo *InboundMemo) ([]byte, error) { +// packFieldsV0 packs the memo fields for version 0 +func (f *FieldsV0) packFields(codec Codec) ([]byte, error) { // create data flags byte dataFlags := byte(0) // add 'receiver' as the first argument - codec.AddArguments(ArgReceiver(memo.Receiver)) + codec.AddArguments(ArgReceiver(f.Receiver)) // add 'payload' argument optionally - if len(memo.Payload) > 0 { + if len(f.Payload) > 0 { zetamath.SetBit(&dataFlags, bitPosPayload) - codec.AddArguments(ArgPayload(memo.Payload)) + codec.AddArguments(ArgPayload(f.Payload)) + } + + // add 'revertAddress' argument optionally + if f.RevertOptions.RevertAddress != "" { + zetamath.SetBit(&dataFlags, bitPosRevertAddress) + codec.AddArguments(ArgRevertAddress(f.RevertOptions.RevertAddress)) } - if memo.RevertOptions != nil { - // add 'revertAddress' argument optionally - if memo.RevertOptions.RevertAddress != "" { - zetamath.SetBit(&dataFlags, bitPosRevertAddress) - codec.AddArguments(ArgRevertAddress(memo.RevertOptions.RevertAddress)) - } - - // add 'abortAddress' argument optionally - if memo.RevertOptions.AbortAddress != "" { - zetamath.SetBit(&dataFlags, bitPosAbortAddress) - codec.AddArguments(ArgAbortAddress(common.HexToAddress(memo.RevertOptions.AbortAddress))) - } - - // add 'revertMessage' argument optionally - if memo.RevertOptions.CallOnRevert { - zetamath.SetBit(&dataFlags, bitPosRevertMessage) - codec.AddArguments(ArgRevertMessage(memo.RevertOptions.RevertMessage)) - } + // add 'abortAddress' argument optionally + abortAddress := common.HexToAddress(f.RevertOptions.AbortAddress) + if !crypto.IsEmptyAddress(abortAddress) { + zetamath.SetBit(&dataFlags, bitPosAbortAddress) + codec.AddArguments(ArgAbortAddress(abortAddress)) + } + + // add 'revertMessage' argument optionally + if f.RevertOptions.CallOnRevert { + zetamath.SetBit(&dataFlags, bitPosRevertMessage) + codec.AddArguments(ArgRevertMessage(f.RevertOptions.RevertMessage)) } // pack the codec arguments into data data, err := codec.PackArguments() - if err != nil { - return nil, err + if err != nil { // never happens + return nil, errors.Wrap(err, "failed to pack arguments") } return append([]byte{dataFlags}, data...), nil } -// UnpackMemoFieldsV0 unpacks the memo fields for version 0 -func UnpackMemoFieldsV0(codec Codec, data []byte, memo *InboundMemo) error { - // byte-2 contains data flags - dataFlags := data[2] +// unpackFields unpacks the memo fields for version 0 +func (f *FieldsV0) unpackFields(codec Codec, data []byte) error { + // get data flags + dataFlags := data[0] // add 'receiver' as the first argument - codec.AddArguments(ArgReceiver(&memo.Receiver)) + codec.AddArguments(ArgReceiver(&f.Receiver)) // add 'payload' argument optionally if zetamath.IsBitSet(dataFlags, bitPosPayload) { - codec.AddArguments(ArgPayload(&memo.Payload)) + codec.AddArguments(ArgPayload(&f.Payload)) } // add 'revertAddress' argument optionally if zetamath.IsBitSet(dataFlags, bitPosRevertAddress) { - codec.AddArguments(ArgRevertAddress(&memo.RevertOptions.RevertAddress)) + codec.AddArguments(ArgRevertAddress(&f.RevertOptions.RevertAddress)) } // add 'abortAddress' argument optionally @@ -115,19 +118,21 @@ func UnpackMemoFieldsV0(codec Codec, data []byte, memo *InboundMemo) error { } // add 'revertMessage' argument optionally - memo.RevertOptions.CallOnRevert = zetamath.IsBitSet(dataFlags, bitPosAbortAddress) - if memo.RevertOptions.CallOnRevert { - codec.AddArguments(ArgRevertMessage(&memo.RevertOptions.RevertMessage)) + f.RevertOptions.CallOnRevert = zetamath.IsBitSet(dataFlags, bitPosAbortAddress) + if f.RevertOptions.CallOnRevert { + codec.AddArguments(ArgRevertMessage(&f.RevertOptions.RevertMessage)) } - // unpack the data (after header) into codec arguments - err := codec.UnpackArguments(data[HeaderSize:]) + // unpack the data (after flags) into codec arguments + err := codec.UnpackArguments(data[1:]) if err != nil { - return err + return errors.Wrap(err, "failed to unpack arguments") } // convert abort address to string - memo.RevertOptions.AbortAddress = abortAddress.Hex() + if !crypto.IsEmptyAddress(abortAddress) { + f.RevertOptions.AbortAddress = abortAddress.Hex() + } return nil } diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go new file mode 100644 index 0000000000..eac811d11d --- /dev/null +++ b/pkg/memo/fields_v0_test.go @@ -0,0 +1,221 @@ +package memo_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" + crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" +) + +func Test_V0_Pack(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + encodingFormat uint8 + fields memo.FieldsV0 + expectedFlags byte + expectedData []byte + errMsg string + }{ + { + name: "pack all fields with ABI encoding", + encodingFormat: memo.EncodingFmtABI, + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + expectedFlags: 0b00001111, // all fields are set + expectedData: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "pack all fields with compact encoding", + encodingFormat: memo.EncodingFmtCompactShort, + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + expectedFlags: 0b00001111, // all fields are set + expectedData: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "should not pack invalid abort address", + encodingFormat: memo.EncodingFmtABI, + fields: memo.FieldsV0{ + Receiver: fAddress, + RevertOptions: crosschaintypes.RevertOptions{ + AbortAddress: "invalid_address", + }, + }, + expectedFlags: 0b00000000, // no flag is set + expectedData: sample.ABIPack(t, memo.ArgReceiver(fAddress)), + }, + { + name: "unable to get codec on invalid encoding format", + encodingFormat: 0x0F, + errMsg: "unable to get codec", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // pack the fields + data, err := tc.fields.Pack(tc.encodingFormat) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // compare the fields + require.NoError(t, err) + require.Equal(t, tc.expectedFlags, data[0]) + require.True(t, bytes.Equal(tc.expectedData, data[1:])) + }) + } +} + +func Test_V0_Unpack(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + encodingFormat uint8 + flags byte + data []byte + expected memo.FieldsV0 + errMsg string + }{ + { + name: "unpack all fields with ABI encoding", + encodingFormat: memo.EncodingFmtABI, + flags: 0b00001111, // all fields are set + data: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expected: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + }, + { + name: "unpack all fields with compact encoding", + encodingFormat: memo.EncodingFmtCompactShort, + flags: 0b00001111, // all fields are set + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expected: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + }, + { + name: "unpack empty ABI encoded payload if flag is set", + encodingFormat: memo.EncodingFmtABI, + flags: 0b00000001, // payload flag is set + data: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload([]byte{})), // empty payload + expected: memo.FieldsV0{ + Receiver: fAddress, + Payload: []byte{}, + }, + }, + { + name: "unpack empty compact encoded payload if flag is not set", + encodingFormat: memo.EncodingFmtCompactShort, + flags: 0b00000001, // payload flag is set + data: sample.CompactPack(t, + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload([]byte{})), // empty payload + expected: memo.FieldsV0{ + Receiver: fAddress, + Payload: []byte{}, + }, + }, + { + name: "failed to unpack ABI encoded data with compact encoding format", + encodingFormat: memo.EncodingFmtCompactShort, + flags: 0b00000001, + data: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes)), + errMsg: "failed to unpack arguments", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // attach data flags + tc.data = append([]byte{tc.flags}, tc.data...) + + // unpack the fields + fields := memo.FieldsV0{} + err := fields.Unpack(tc.data, tc.encodingFormat) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // compare the fields + require.NoError(t, err) + require.Equal(t, tc.expected, fields) + }) + } +} diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index bbe785e3ce..fa2b4570f3 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -33,22 +33,6 @@ const ( InboundOpCodeMax uint8 = 0b011 // operation max value ) -// Encoder is the interface for outbound memo encoders -type Encoder func(memo *InboundMemo) ([]byte, error) - -// Decoder is the interface for inbound memo decoders -type Decoder func(data []byte, memo *InboundMemo) error - -// memoEncoderRegistry contains all registered memo encoders -var memoEncoderRegistry = map[uint8]Encoder{ - 0: FieldsEncoderV0, -} - -// memoDecoderRegistry contains all registered memo decoders -var memoDecoderRegistry = map[uint8]Decoder{ - 0: FieldsDecoderV0, -} - // InboundMemo represents the memo structure for non-EVM chains type InboundMemo struct { // Version is the memo Version @@ -61,79 +45,79 @@ type InboundMemo struct { OpCode uint8 // FieldsV0 contains the memo fields V0 - // Note: add a MemoFieldsV1 if major change is needed in the future + // Note: add a FieldsV1 if major update is needed in the future FieldsV0 } -// EncodeMemoToBytes encodes a InboundMemo struct to raw bytes -func EncodeMemoToBytes(memo *InboundMemo) ([]byte, error) { - // get registered memo encoder by version - encoder, found := memoEncoderRegistry[memo.Version] - if !found { - return nil, fmt.Errorf("encoder not found for memo version: %d", memo.Version) - } - +// EncodeToBytes encodes a InboundMemo struct to raw bytes +func EncodeToBytes(memo *InboundMemo) ([]byte, error) { // encode memo basics - basics := EncodeMemoBasics(memo) - - // encode memo fields using the encoder - data, err := encoder(memo) + basics := encodeBasics(memo) + + // encode memo fields based on version + var data []byte + var err error + switch memo.Version { + case 0: + data, err = memo.FieldsV0.Pack(memo.EncodingFormat) + default: + return nil, fmt.Errorf("unsupported memo version: %d", memo.Version) + } if err != nil { - return nil, errors.Wrap(err, "failed to encode memo fields") + return nil, errors.Wrapf(err, "failed to pack memo fields version: %d", memo.Version) } return append(basics, data...), nil } -// EncodeMemoBasics encodes the version, encoding format and operation code -func EncodeMemoBasics(memo *InboundMemo) []byte { - // create 3-byte header - head := make([]byte, HeaderSize) - - // set byte-0 as memo identifier - head[0] = MemoIdentifier - - // set version # and encoding format - var ctrlByte byte - ctrlByte = zetamath.SetBits(ctrlByte, BitMaskVersion, memo.Version) - ctrlByte = zetamath.SetBits(ctrlByte, BitMaskEncodingFormat, memo.EncodingFormat) - ctrlByte = zetamath.SetBits(ctrlByte, BitMaskOpCode, memo.OpCode) - - // set ctrlByte to byte-1 - head[1] = ctrlByte - - return head -} - -// DecodeMemoFromBytes decodes a InboundMemo struct from raw bytes +// DecodeFromBytes decodes a InboundMemo struct from raw bytes // // Returns an error if given data is not a valid memo -func DecodeMemoFromBytes(data []byte) (*InboundMemo, error) { +func DecodeFromBytes(data []byte) (*InboundMemo, error) { memo := &InboundMemo{} // decode memo basics - err := DecodeMemoBasics(data, memo) + err := decodeBasics(data, memo) if err != nil { return nil, err } - // get registered memo decoder by version - decoder, found := memoDecoderRegistry[memo.Version] - if !found { - return nil, fmt.Errorf("decoder not found for memo version: %d", memo.Version) + // decode memo fields based on version + switch memo.Version { + case 0: + err = memo.FieldsV0.Unpack(data, memo.EncodingFormat) + default: + return nil, fmt.Errorf("unsupported memo version: %d", memo.Version) } - - // decode memo fields using the decoer - err = decoder(data, memo) if err != nil { - return nil, errors.Wrap(err, "failed to decode memo fields") + return nil, errors.Wrapf(err, "failed to unpack memo fields version: %d", memo.Version) } return memo, nil } -// DecodeMemoBasics decodes version, encoding format and operation code -func DecodeMemoBasics(data []byte, memo *InboundMemo) error { +// encodeBasics encodes the version, encoding format and operation code +func encodeBasics(memo *InboundMemo) []byte { + // 2 bytes: [identifier + ctrlByte] + basics := make([]byte, HeaderSize-1) + + // set byte-0 as memo identifier + basics[0] = MemoIdentifier + + // set version # and encoding format + var ctrlByte byte + ctrlByte = zetamath.SetBits(ctrlByte, BitMaskVersion, memo.Version) + ctrlByte = zetamath.SetBits(ctrlByte, BitMaskEncodingFormat, memo.EncodingFormat) + ctrlByte = zetamath.SetBits(ctrlByte, BitMaskOpCode, memo.OpCode) + + // set ctrlByte to byte-1 + basics[1] = ctrlByte + + return basics +} + +// decodeBasics decodes version, encoding format and operation code +func decodeBasics(data []byte, memo *InboundMemo) error { // memo data must be longer than the header size if len(data) <= HeaderSize { return errors.New("memo data too short") diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go index 17f080695f..f34354c392 100644 --- a/testutil/sample/memo.go +++ b/testutil/sample/memo.go @@ -22,8 +22,11 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { // 2. pack actual data for each fixed-length field (address) for _, arg := range args { switch arg.Type { - case memo.ArgTypeBytes: // left-pad for uint8 - offsetData := abiPad32(t, []byte{byte(offset)}, true) + case memo.ArgTypeBytes: + // left-pad length as uint16 + buff := make([]byte, 2) + binary.BigEndian.PutUint16(buff, uint16(offset)) + offsetData := abiPad32(t, buff, true) packedData = append(packedData, offsetData...) argLen := len(arg.Arg.([]byte)) @@ -33,8 +36,11 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { offset += memo.ABIAlignment // only [length] } - case memo.ArgTypeString: // left-pad for uint8 - offsetData := abiPad32(t, []byte{byte(offset)}, true) + case memo.ArgTypeString: + // left-pad length as uint16 + buff := make([]byte, 2) + binary.BigEndian.PutUint16(buff, uint16(offset)) + offsetData := abiPad32(t, buff, true) packedData = append(packedData, offsetData...) argLen := len([]byte(arg.Arg.(string))) From bde439f7bf18ee211bf63f6087019f920ed09ba8 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Wed, 9 Oct 2024 15:51:50 -0500 Subject: [PATCH 03/19] added memo unit tests --- pkg/memo/codec_compact_test.go | 16 +- pkg/memo/fields_v0.go | 17 +- pkg/memo/fields_v0_test.go | 15 +- pkg/memo/memo.go | 146 ++++++++------ pkg/memo/memo_test.go | 350 ++++++++++++++++++++++++++++++++- testutil/sample/memo.go | 13 +- 6 files changed, 482 insertions(+), 75 deletions(-) diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go index 4df32fe915..90f9127c08 100644 --- a/pkg/memo/codec_compact_test.go +++ b/pkg/memo/codec_compact_test.go @@ -169,7 +169,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { require.Equal(t, tc.expectedLen, len(packedData)) // calc expected data for comparison - expectedData := sample.CompactPack(t, tc.encodingFmt, tc.args...) + expectedData := sample.CompactPack(tc.encodingFmt, tc.args...) // validate the packed data require.True(t, bytes.Equal(expectedData, packedData), "compact encoded data mismatch") @@ -194,7 +194,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "unpack arguments of [address, bytes, string] in compact-short format", encodingFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes), @@ -209,7 +209,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "unpack arguments of [string, address, bytes] in compact-long format", encodingFmt: memo.EncodingFmtCompactLong, - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactLong, memo.ArgRevertAddress(argString), memo.ArgReceiver(argAddress), @@ -224,7 +224,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "unpack empty byte array and string argument", encodingFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload([]byte{}), memo.ArgRevertAddress(""), @@ -255,7 +255,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "failed to unpack bytes argument if string is passed", encodingFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload(argBytes), ), @@ -267,7 +267,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "failed to unpack address argument if bytes is passed", encodingFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), ), @@ -279,7 +279,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "failed to unpack string argument if address is passed", encodingFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgRevertAddress(argString), ), @@ -301,7 +301,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { name: "unpacking should fail if not all data is consumed", encodingFmt: memo.EncodingFmtCompactShort, data: func() []byte { - data := sample.CompactPack(t, + data := sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes), diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index c3082417fa..4dec31e2ea 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -1,6 +1,8 @@ package memo import ( + "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" @@ -17,6 +19,11 @@ const ( bitPosRevertMessage uint8 = 3 // revertMessage ) +const ( + // MaskFlagReserved is the mask for reserved data flags + MaskFlagReserved = 0b11110000 +) + var _ Fields = (*FieldsV0)(nil) // FieldsV0 contains the data fields of the inbound memo V0 @@ -113,16 +120,22 @@ func (f *FieldsV0) unpackFields(codec Codec, data []byte) error { // add 'abortAddress' argument optionally var abortAddress common.Address - if zetamath.IsBitSet(dataFlags, bitPosRevertMessage) { + if zetamath.IsBitSet(dataFlags, bitPosAbortAddress) { codec.AddArguments(ArgAbortAddress(&abortAddress)) } // add 'revertMessage' argument optionally - f.RevertOptions.CallOnRevert = zetamath.IsBitSet(dataFlags, bitPosAbortAddress) + f.RevertOptions.CallOnRevert = zetamath.IsBitSet(dataFlags, bitPosRevertMessage) if f.RevertOptions.CallOnRevert { codec.AddArguments(ArgRevertMessage(&f.RevertOptions.RevertMessage)) } + // all reserved flag bits must be zero + reserved := zetamath.GetBits(dataFlags, MaskFlagReserved) + if reserved != 0 { + return fmt.Errorf("reserved flag bits are not zero: %d", reserved) + } + // unpack the data (after flags) into codec arguments err := codec.UnpackArguments(data[1:]) if err != nil { diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index eac811d11d..966370007a 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -59,7 +59,7 @@ func Test_V0_Pack(t *testing.T) { }, }, expectedFlags: 0b00001111, // all fields are set - expectedData: sample.CompactPack(t, + expectedData: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -144,7 +144,7 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack all fields with compact encoding", encodingFormat: memo.EncodingFmtCompactShort, flags: 0b00001111, // all fields are set - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -178,7 +178,7 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack empty compact encoded payload if flag is not set", encodingFormat: memo.EncodingFmtCompactShort, flags: 0b00000001, // payload flag is set - data: sample.CompactPack(t, + data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), memo.ArgPayload([]byte{})), // empty payload @@ -196,6 +196,15 @@ func Test_V0_Unpack(t *testing.T) { memo.ArgPayload(fBytes)), errMsg: "failed to unpack arguments", }, + { + name: "failed to unpack data if reserved flag is not zero", + encodingFormat: memo.EncodingFmtABI, + flags: 0b00100001, // payload flag and reserved bit5 are set + data: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes)), + errMsg: "reserved flag bits are not zero", + }, } for _, tc := range tests { diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index fa2b4570f3..1b4c20cf94 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -12,25 +12,31 @@ const ( // MemoIdentifier is the ASCII code of 'Z' (0x5A) MemoIdentifier byte = 0x5A - // HeaderSize is the size of the memo header - HeaderSize = 3 + // MemoHeaderSize is the size of the memo header: [identifier + ctrlByte1+ ctrlByte2 + dataFlags] + MemoHeaderSize = 4 - // MaskVersion is the mask for the version bits(5~7) - BitMaskVersion byte = 0b11100000 + // MemoBasicsSize is the size of the memo basics: [identifier + ctrlByte1 + ctrlByte2] + MemoBasicsSize = 3 - // BitMaskEncodingFormat is the mask for the encoding format bits(3~4) - BitMaskEncodingFormat byte = 0b00011000 + // MaskVersion is the mask for the version bits(upper 4 bits) + MaskVersion byte = 0b11110000 - // BitMaskOpCode is the mask for the operation code bits(0~2) - BitMaskOpCode byte = 0b00000111 + // MaskEncodingFormat is the mask for the encoding format bits(lower 4 bits) + MaskEncodingFormat byte = 0b00001111 + + // MaskOpCode is the mask for the operation code bits(upper 4 bits) + MaskOpCode byte = 0b11110000 + + // MaskCtrlReserved is the mask for reserved control bits (lower 4 bits) + MaskCtrlReserved byte = 0b00001111 ) -// Enum for non-EVM chain inbound operation code (3 bits) +// Enum for non-EVM chain inbound operation code (4 bits) const ( - InboundOpCodeDeposit uint8 = 0b000 // operation 'deposit' - InboundOpCodeDepositAndCall uint8 = 0b001 // operation 'deposit_and_call' - InboundOpCodeCall uint8 = 0b010 // operation 'call' - InboundOpCodeMax uint8 = 0b011 // operation max value + OpCodeDeposit uint8 = 0b0000 // operation 'deposit' + OpCodeDepositAndCall uint8 = 0b0001 // operation 'deposit_and_call' + OpCodeCall uint8 = 0b0010 // operation 'call' + OpCodeMax uint8 = 0b0011 // operation max value ) // InboundMemo represents the memo structure for non-EVM chains @@ -44,27 +50,35 @@ type InboundMemo struct { // OpCode is the inbound operation code OpCode uint8 + // Reserved is the reserved control bits + Reserved uint8 + // FieldsV0 contains the memo fields V0 // Note: add a FieldsV1 if major update is needed in the future FieldsV0 } // EncodeToBytes encodes a InboundMemo struct to raw bytes -func EncodeToBytes(memo *InboundMemo) ([]byte, error) { +func (m *InboundMemo) EncodeToBytes() ([]byte, error) { + // validate memo basics + err := m.ValidateBasics() + if err != nil { + return nil, err + } + // encode memo basics - basics := encodeBasics(memo) + basics := m.EncodeBasics() // encode memo fields based on version var data []byte - var err error - switch memo.Version { + switch m.Version { case 0: - data, err = memo.FieldsV0.Pack(memo.EncodingFormat) + data, err = m.FieldsV0.Pack(m.EncodingFormat) default: - return nil, fmt.Errorf("unsupported memo version: %d", memo.Version) + return nil, fmt.Errorf("invalid memo version: %d", m.Version) } if err != nil { - return nil, errors.Wrapf(err, "failed to pack memo fields version: %d", memo.Version) + return nil, errors.Wrapf(err, "failed to pack memo fields version: %d", m.Version) } return append(basics, data...), nil @@ -77,7 +91,13 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { memo := &InboundMemo{} // decode memo basics - err := decodeBasics(data, memo) + err := memo.DecodeBasics(data) + if err != nil { + return nil, err + } + + // validate memo basics + err = memo.ValidateBasics() if err != nil { return nil, err } @@ -85,9 +105,9 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { // decode memo fields based on version switch memo.Version { case 0: - err = memo.FieldsV0.Unpack(data, memo.EncodingFormat) + err = memo.FieldsV0.Unpack(data[MemoBasicsSize:], memo.EncodingFormat) default: - return nil, fmt.Errorf("unsupported memo version: %d", memo.Version) + return nil, fmt.Errorf("invalid memo version: %d", memo.Version) } if err != nil { return nil, errors.Wrapf(err, "failed to unpack memo fields version: %d", memo.Version) @@ -96,58 +116,68 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { return memo, nil } -// encodeBasics encodes the version, encoding format and operation code -func encodeBasics(memo *InboundMemo) []byte { - // 2 bytes: [identifier + ctrlByte] - basics := make([]byte, HeaderSize-1) +// Validate checks if the memo is valid +func (m *InboundMemo) ValidateBasics() error { + if m.EncodingFormat >= EncodingFmtMax { + return fmt.Errorf("invalid encoding format: %d", m.EncodingFormat) + } + + if m.OpCode >= OpCodeMax { + return fmt.Errorf("invalid operation code: %d", m.OpCode) + } + + // reserved control bits must be zero + if m.Reserved != 0 { + return fmt.Errorf("reserved control bits are not zero: %d", m.Reserved) + } + + return nil +} + +// EncodeBasics encodes theidentifier, version, encoding format and operation code +func (m *InboundMemo) EncodeBasics() []byte { + // basics: [identifier + ctrlByte1 + ctrlByte2] + basics := make([]byte, MemoBasicsSize) // set byte-0 as memo identifier basics[0] = MemoIdentifier - // set version # and encoding format - var ctrlByte byte - ctrlByte = zetamath.SetBits(ctrlByte, BitMaskVersion, memo.Version) - ctrlByte = zetamath.SetBits(ctrlByte, BitMaskEncodingFormat, memo.EncodingFormat) - ctrlByte = zetamath.SetBits(ctrlByte, BitMaskOpCode, memo.OpCode) + // set version #, encoding format + var ctrlByte1 byte + ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskVersion, m.Version) + ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskEncodingFormat, m.EncodingFormat) + basics[1] = ctrlByte1 - // set ctrlByte to byte-1 - basics[1] = ctrlByte + // set operation code, reserved bits + var ctrlByte2 byte + ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskOpCode, m.OpCode) + ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskCtrlReserved, m.Reserved) + basics[2] = ctrlByte2 return basics } -// decodeBasics decodes version, encoding format and operation code -func decodeBasics(data []byte, memo *InboundMemo) error { +// DecodeBasics decodes the identifier, version, encoding format and operation code +func (m *InboundMemo) DecodeBasics(data []byte) error { // memo data must be longer than the header size - if len(data) <= HeaderSize { - return errors.New("memo data too short") + if len(data) <= MemoHeaderSize { + return errors.New("memo is too short") } // byte-0 is the memo identifier if data[0] != MemoIdentifier { - return errors.New("memo identifier mismatch") - } - - // byte-1 is the control byte - ctrlByte := data[1] - - // extract version # - memo.Version = zetamath.GetBits(ctrlByte, BitMaskVersion) - if memo.Version != 0 { - return fmt.Errorf("unsupported memo version: %d", memo.Version) + return fmt.Errorf("invalid memo identifier: %d", data[0]) } - // extract encoding format - memo.EncodingFormat = zetamath.GetBits(ctrlByte, BitMaskEncodingFormat) - if memo.EncodingFormat >= EncodingFmtMax { - return fmt.Errorf("invalid encoding format: %d", memo.EncodingFormat) - } + // extract version #, encoding format + ctrlByte1 := data[1] + m.Version = zetamath.GetBits(ctrlByte1, MaskVersion) + m.EncodingFormat = zetamath.GetBits(ctrlByte1, MaskEncodingFormat) - // extract operation code - memo.OpCode = zetamath.GetBits(ctrlByte, BitMaskOpCode) - if memo.OpCode >= InboundOpCodeMax { - return fmt.Errorf("invalid operation code: %d", memo.OpCode) - } + // extract operation code, reserved bits + ctrlByte2 := data[2] + m.OpCode = zetamath.GetBits(ctrlByte2, MaskOpCode) + m.Reserved = zetamath.GetBits(ctrlByte2, MaskCtrlReserved) return nil } diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index add6944fdf..dbfc011472 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -1,10 +1,354 @@ package memo_test import ( - "encoding/hex" "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" + crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" ) -func Test_NewInboundMemo(t *testing.T) { - hex.EncodeToString([]byte("hello")) +func Test_EncodeToBytes(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + memo *memo.InboundMemo + expectedHead []byte + expectedData []byte + errMsg string + }{ + { + name: "encode memo with ABI encoding", + memo: &memo.InboundMemo{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + expectedHead: sample.MemoHead( + 0, + memo.EncodingFmtABI, + memo.OpCodeDepositAndCall, + 0, + 0b00001111, + ), // all fields are set + expectedData: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "encode memo with compact encoding", + memo: &memo.InboundMemo{ + Version: 0, + EncodingFormat: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + expectedHead: sample.MemoHead( + 0, + memo.EncodingFmtCompactShort, + memo.OpCodeDepositAndCall, + 0, + 0b00001111, + ), // all fields are set + expectedData: sample.CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "failed to encode if basic validation fails", + memo: &memo.InboundMemo{ + EncodingFormat: memo.EncodingFmtMax, + }, + errMsg: "invalid encoding format", + }, + { + name: "failed to encode if version is invalid", + memo: &memo.InboundMemo{ + Version: 1, + }, + errMsg: "invalid memo version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := tt.memo.EncodeToBytes() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + require.Equal(t, append(tt.expectedHead, tt.expectedData...), data) + }) + } +} + +func Test_DecodeFromBytes(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + head []byte + data []byte + expectedMemo memo.InboundMemo + errMsg string + }{ + { + name: "decode memo with ABI encoding", + head: sample.MemoHead( + 0, + memo.EncodingFmtABI, + memo.OpCodeDepositAndCall, + 0, + 0b00001111, + ), // all fields are set + data: sample.ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expectedMemo: memo.InboundMemo{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + }, + { + name: "decode memo with compact encoding", + head: sample.MemoHead( + 0, + memo.EncodingFmtCompactLong, + memo.OpCodeDepositAndCall, + 0, + 0b00001111, + ), // all fields are set + data: sample.CompactPack( + memo.EncodingFmtCompactLong, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expectedMemo: memo.InboundMemo{ + Version: 0, + EncodingFormat: memo.EncodingFmtCompactLong, + OpCode: memo.OpCodeDepositAndCall, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + }, + { + name: "failed to decode if basic validation fails", + head: sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeMax, 0, 0), + data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), + errMsg: "invalid operation code", + }, + { + name: "failed to decode if version is invalid", + head: sample.MemoHead(1, memo.EncodingFmtABI, memo.OpCodeDeposit, 0, 0), + data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), + errMsg: "invalid memo version", + }, + { + + name: "failed to decode compact encoded data with ABI encoding format", + head: sample.MemoHead( + 0, + memo.EncodingFmtABI, + memo.OpCodeDepositAndCall, + 0, + 0, + ), // head says ABI encoding + data: sample.CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + ), // but data is compact encoded + errMsg: "failed to unpack memo fields", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := append(tt.head, tt.data...) + memo, err := memo.DecodeFromBytes(data) + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + require.Equal(t, tt.expectedMemo, *memo) + }) + } +} + +func Test_ValidateBasics(t *testing.T) { + tests := []struct { + name string + memo *memo.InboundMemo + errMsg string + }{ + { + name: "valid memo", + memo: &memo.InboundMemo{ + Version: 0, + EncodingFormat: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, + }, + }, + { + name: "invalid encoding format", + memo: &memo.InboundMemo{ + EncodingFormat: memo.EncodingFmtMax, + }, + errMsg: "invalid encoding format", + }, + { + name: "invalid operation code", + memo: &memo.InboundMemo{ + OpCode: memo.OpCodeMax, + }, + errMsg: "invalid operation code", + }, + { + name: "reserved field is not zero", + memo: &memo.InboundMemo{ + Reserved: 1, + }, + errMsg: "reserved control bits are not zero", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.memo.ValidateBasics() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + }) + } +} + +func Test_EncodeBasics(t *testing.T) { + tests := []struct { + name string + memo *memo.InboundMemo + expected []byte + errMsg string + }{ + { + name: "it works", + memo: &memo.InboundMemo{ + Version: 1, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + Reserved: 15, + }, + expected: []byte{memo.MemoIdentifier, 0b00010000, 0b00101111}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + basics := tt.memo.EncodeBasics() + require.Equal(t, tt.expected, basics) + }) + } +} + +func Test_DecodeBasics(t *testing.T) { + tests := []struct { + name string + data []byte + expected memo.InboundMemo + errMsg string + }{ + { + name: "it works", + data: append(sample.MemoHead(1, memo.EncodingFmtABI, memo.OpCodeCall, 15, 0), []byte{0x01, 0x02}...), + expected: memo.InboundMemo{ + Version: 1, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + Reserved: 15, + }, + }, + { + name: "memo is too short", + data: []byte{0x01, 0x02, 0x03, 0x04}, + errMsg: "memo is too short", + }, + { + name: "invalid memo identifier", + data: []byte{'M', 0x02, 0x03, 0x04, 0x05}, + errMsg: "invalid memo identifier", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + memo := &memo.InboundMemo{} + err := memo.DecodeBasics(tt.data) + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, *memo) + }) + } } diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go index f34354c392..137504f573 100644 --- a/testutil/sample/memo.go +++ b/testutil/sample/memo.go @@ -10,6 +10,17 @@ import ( "github.com/zeta-chain/node/pkg/memo" ) +// MemoHead is a helper function to create a memo head +// Note: all arguments are assume to be <= 0b1111 for simplicity. +func MemoHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { + head := make([]byte, memo.MemoHeaderSize) + head[0] = memo.MemoIdentifier + head[1] = version<<4 | encodingFmt + head[2] = opCode<<4 | reserved + head[3] = flags + return head +} + // ABIPack is a helper function to simulates the abi.Pack function. // Note: all arguments are assumed to be <= 32 bytes for simplicity. func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { @@ -65,7 +76,7 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { // CompactPack is a helper function to pack arguments into compact encoded data // Note: all arguments are assumed to be <= 65535 bytes for simplicity. -func CompactPack(_ *testing.T, encodingFmt uint8, args ...memo.CodecArg) []byte { +func CompactPack(encodingFmt uint8, args ...memo.CodecArg) []byte { var ( length int packedData []byte From 9272a2acabb2a688c0fbd9d2f1bf6d6fbded1941 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Wed, 9 Oct 2024 23:06:05 -0500 Subject: [PATCH 04/19] use separate file for memo header; add more unit tests --- pkg/memo/fields.go | 7 +- pkg/memo/fields_v0.go | 57 +++++++----- pkg/memo/fields_v0_test.go | 108 +++++++++++++++++++--- pkg/memo/header.go | 137 +++++++++++++++++++++++++++ pkg/memo/header_test.go | 155 +++++++++++++++++++++++++++++++ pkg/memo/memo.go | 152 +++++------------------------- pkg/memo/memo_test.go | 185 ++++++++++--------------------------- testutil/sample/memo.go | 4 +- 8 files changed, 500 insertions(+), 305 deletions(-) create mode 100644 pkg/memo/header.go create mode 100644 pkg/memo/header_test.go diff --git a/pkg/memo/fields.go b/pkg/memo/fields.go index 99118a1da9..06def4a6b5 100644 --- a/pkg/memo/fields.go +++ b/pkg/memo/fields.go @@ -3,8 +3,11 @@ package memo // Fields is the interface for memo fields type Fields interface { // Pack encodes the memo fields - Pack(encodingFormat uint8) ([]byte, error) + Pack(opCode, encodingFormat uint8) (byte, []byte, error) // Unpack decodes the memo fields - Unpack(data []byte, encodingFormat uint8) error + Unpack(opCode, encodingFormat, dataFlags uint8, data []byte) error + + // Validate checks if the fields are valid + Validate(opCode uint8) error } diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index 4dec31e2ea..a9f2d1b3b3 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -1,8 +1,6 @@ package memo import ( - "fmt" - "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" @@ -39,29 +37,55 @@ type FieldsV0 struct { } // Pack encodes the memo fields -func (f *FieldsV0) Pack(encodingFormat uint8) ([]byte, error) { +func (f *FieldsV0) Pack(opCode uint8, encodingFormat uint8) (byte, []byte, error) { + // validate fields + err := f.Validate(opCode) + if err != nil { + return 0, nil, err + } + codec, err := GetCodec(encodingFormat) if err != nil { - return nil, errors.Wrap(err, "unable to get codec") + return 0, nil, errors.Wrap(err, "unable to get codec") } return f.packFields(codec) } // Unpack decodes the memo fields -func (f *FieldsV0) Unpack(data []byte, encodingFormat uint8) error { +func (f *FieldsV0) Unpack(opCode uint8, encodingFormat uint8, dataFlags byte, data []byte) error { codec, err := GetCodec(encodingFormat) if err != nil { return errors.Wrap(err, "unable to get codec") } - return f.unpackFields(codec, data) + err = f.unpackFields(codec, dataFlags, data) + if err != nil { + return err + } + + return f.Validate(opCode) +} + +// Validate checks if the fields are valid +func (f *FieldsV0) Validate(opCode uint8) error { + // check if receiver is empty + if crypto.IsEmptyAddress(f.Receiver) { + return errors.New("receiver address is empty") + } + + // ensure payload is not set for deposit operation + if opCode == OpCodeDeposit && len(f.Payload) > 0 { + return errors.New("payload is not allowed for deposit operation") + } + + return nil } // packFieldsV0 packs the memo fields for version 0 -func (f *FieldsV0) packFields(codec Codec) ([]byte, error) { +func (f *FieldsV0) packFields(codec Codec) (byte, []byte, error) { // create data flags byte - dataFlags := byte(0) + var dataFlags byte // add 'receiver' as the first argument codec.AddArguments(ArgReceiver(f.Receiver)) @@ -94,17 +118,14 @@ func (f *FieldsV0) packFields(codec Codec) ([]byte, error) { // pack the codec arguments into data data, err := codec.PackArguments() if err != nil { // never happens - return nil, errors.Wrap(err, "failed to pack arguments") + return 0, nil, errors.Wrap(err, "failed to pack arguments") } - return append([]byte{dataFlags}, data...), nil + return dataFlags, data, nil } // unpackFields unpacks the memo fields for version 0 -func (f *FieldsV0) unpackFields(codec Codec, data []byte) error { - // get data flags - dataFlags := data[0] - +func (f *FieldsV0) unpackFields(codec Codec, dataFlags byte, data []byte) error { // add 'receiver' as the first argument codec.AddArguments(ArgReceiver(&f.Receiver)) @@ -130,14 +151,8 @@ func (f *FieldsV0) unpackFields(codec Codec, data []byte) error { codec.AddArguments(ArgRevertMessage(&f.RevertOptions.RevertMessage)) } - // all reserved flag bits must be zero - reserved := zetamath.GetBits(dataFlags, MaskFlagReserved) - if reserved != 0 { - return fmt.Errorf("reserved flag bits are not zero: %d", reserved) - } - // unpack the data (after flags) into codec arguments - err := codec.UnpackArguments(data[1:]) + err := codec.UnpackArguments(data) if err != nil { return errors.Wrap(err, "failed to unpack arguments") } diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index 966370007a..6d3b8c6706 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -4,6 +4,7 @@ import ( "bytes" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/testutil/sample" @@ -18,6 +19,7 @@ func Test_V0_Pack(t *testing.T) { tests := []struct { name string + opCode uint8 encodingFormat uint8 fields memo.FieldsV0 expectedFlags byte @@ -26,6 +28,7 @@ func Test_V0_Pack(t *testing.T) { }{ { name: "pack all fields with ABI encoding", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, fields: memo.FieldsV0{ Receiver: fAddress, @@ -47,6 +50,7 @@ func Test_V0_Pack(t *testing.T) { }, { name: "pack all fields with compact encoding", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, fields: memo.FieldsV0{ Receiver: fAddress, @@ -69,6 +73,7 @@ func Test_V0_Pack(t *testing.T) { }, { name: "should not pack invalid abort address", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, fields: memo.FieldsV0{ Receiver: fAddress, @@ -80,7 +85,20 @@ func Test_V0_Pack(t *testing.T) { expectedData: sample.ABIPack(t, memo.ArgReceiver(fAddress)), }, { - name: "unable to get codec on invalid encoding format", + name: "fields validation failed due to empty receiver address", + opCode: memo.OpCodeDepositAndCall, + encodingFormat: memo.EncodingFmtABI, + fields: memo.FieldsV0{ + Receiver: common.Address{}, + }, + errMsg: "receiver address is empty", + }, + { + name: "unable to get codec on invalid encoding format", + opCode: memo.OpCodeDepositAndCall, + fields: memo.FieldsV0{ + Receiver: fAddress, + }, encodingFormat: 0x0F, errMsg: "unable to get codec", }, @@ -89,18 +107,20 @@ func Test_V0_Pack(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // pack the fields - data, err := tc.fields.Pack(tc.encodingFormat) + flags, data, err := tc.fields.Pack(tc.opCode, tc.encodingFormat) // validate the error message if tc.errMsg != "" { require.ErrorContains(t, err, tc.errMsg) + require.Zero(t, flags) + require.Nil(t, data) return } // compare the fields require.NoError(t, err) - require.Equal(t, tc.expectedFlags, data[0]) - require.True(t, bytes.Equal(tc.expectedData, data[1:])) + require.Equal(t, tc.expectedFlags, flags) + require.True(t, bytes.Equal(tc.expectedData, data)) }) } } @@ -113,6 +133,7 @@ func Test_V0_Unpack(t *testing.T) { tests := []struct { name string + opCode uint8 encodingFormat uint8 flags byte data []byte @@ -121,6 +142,7 @@ func Test_V0_Unpack(t *testing.T) { }{ { name: "unpack all fields with ABI encoding", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, flags: 0b00001111, // all fields are set data: sample.ABIPack(t, @@ -142,6 +164,7 @@ func Test_V0_Unpack(t *testing.T) { }, { name: "unpack all fields with compact encoding", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, flags: 0b00001111, // all fields are set data: sample.CompactPack( @@ -164,6 +187,7 @@ func Test_V0_Unpack(t *testing.T) { }, { name: "unpack empty ABI encoded payload if flag is set", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, flags: 0b00000001, // payload flag is set data: sample.ABIPack(t, @@ -176,6 +200,7 @@ func Test_V0_Unpack(t *testing.T) { }, { name: "unpack empty compact encoded payload if flag is not set", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, flags: 0b00000001, // payload flag is set data: sample.CompactPack( @@ -189,6 +214,7 @@ func Test_V0_Unpack(t *testing.T) { }, { name: "failed to unpack ABI encoded data with compact encoding format", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, flags: 0b00000001, data: sample.ABIPack(t, @@ -197,24 +223,22 @@ func Test_V0_Unpack(t *testing.T) { errMsg: "failed to unpack arguments", }, { - name: "failed to unpack data if reserved flag is not zero", + name: "fields validation failed due to empty receiver address", + opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, - flags: 0b00100001, // payload flag and reserved bit5 are set + flags: 0b00000001, data: sample.ABIPack(t, - memo.ArgReceiver(fAddress), + memo.ArgReceiver(common.Address{}), memo.ArgPayload(fBytes)), - errMsg: "reserved flag bits are not zero", + errMsg: "receiver address is empty", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // attach data flags - tc.data = append([]byte{tc.flags}, tc.data...) - // unpack the fields fields := memo.FieldsV0{} - err := fields.Unpack(tc.data, tc.encodingFormat) + err := fields.Unpack(tc.opCode, tc.encodingFormat, tc.flags, tc.data) // validate the error message if tc.errMsg != "" { @@ -228,3 +252,63 @@ func Test_V0_Unpack(t *testing.T) { }) } } + +func Test_V0_Validate(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + opCode uint8 + fields memo.FieldsV0 + errMsg string + }{ + { + name: "valid fields", + opCode: memo.OpCodeDepositAndCall, + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + }, + { + name: "invalid receiver address", + opCode: memo.OpCodeCall, + fields: memo.FieldsV0{ + Receiver: common.Address{}, // empty receiver address + }, + errMsg: "receiver address is empty", + }, + { + name: "payload is not allowed when opCode is deposit", + opCode: memo.OpCodeDeposit, + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, // payload is mistakenly set + }, + errMsg: "payload is not allowed for deposit operation", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // validate the fields + err := tc.fields.Validate(tc.opCode) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + require.NoError(t, err) + }) + } +} diff --git a/pkg/memo/header.go b/pkg/memo/header.go new file mode 100644 index 0000000000..b2638f38b8 --- /dev/null +++ b/pkg/memo/header.go @@ -0,0 +1,137 @@ +package memo + +import ( + "fmt" + + "github.com/pkg/errors" + + zetamath "github.com/zeta-chain/node/pkg/math" +) + +const ( + // Identifier is the ASCII code of 'Z' (0x5A) + Identifier byte = 0x5A + + // HeaderSize is the size of the memo header: [identifier + ctrlByte1+ ctrlByte2 + dataFlags] + HeaderSize = 4 + + // MaskVersion is the mask for the version bits(upper 4 bits) + MaskVersion byte = 0b11110000 + + // MaskEncodingFormat is the mask for the encoding format bits(lower 4 bits) + MaskEncodingFormat byte = 0b00001111 + + // MaskOpCode is the mask for the operation code bits(upper 4 bits) + MaskOpCode byte = 0b11110000 + + // MaskCtrlReserved is the mask for reserved control bits (lower 4 bits) + MaskCtrlReserved byte = 0b00001111 +) + +// Enum for non-EVM chain inbound operation code (4 bits) +const ( + OpCodeDeposit uint8 = 0b0000 // operation 'deposit' + OpCodeDepositAndCall uint8 = 0b0001 // operation 'deposit_and_call' + OpCodeCall uint8 = 0b0010 // operation 'call' + OpCodeMax uint8 = 0b0011 // operation max value +) + +// Header represent the memo header +type Header struct { + // Version is the memo Version + Version uint8 + + // EncodingFormat is the memo encoding format + EncodingFormat uint8 + + // OpCode is the inbound operation code + OpCode uint8 + + // Reserved is the reserved control bits + Reserved uint8 + + // DataFlags is the data flags + DataFlags uint8 +} + +// EncodeToBytes encodes the memo header to raw bytes +func (h *Header) EncodeToBytes() ([]byte, error) { + // validate header + if err := h.Validate(); err != nil { + return nil, err + } + + // create buffer for the header + data := make([]byte, HeaderSize) + + // set byte-0 as memo identifier + data[0] = Identifier + + // set version #, encoding format + var ctrlByte1 byte + ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskVersion, h.Version) + ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskEncodingFormat, h.EncodingFormat) + data[1] = ctrlByte1 + + // set operation code, reserved bits + var ctrlByte2 byte + ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskOpCode, h.OpCode) + ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskCtrlReserved, h.Reserved) + data[2] = ctrlByte2 + + // set data flags + data[3] = h.DataFlags + + return data, nil +} + +// DecodeFromBytes decodes the memo header from the given data +func (h *Header) DecodeFromBytes(data []byte) error { + // memo data must be longer than the header size + if len(data) <= HeaderSize { + return errors.New("memo is too short") + } + + // byte-0 is the memo identifier + if data[0] != Identifier { + return fmt.Errorf("invalid memo identifier: %d", data[0]) + } + + // extract version #, encoding format + ctrlByte1 := data[1] + h.Version = zetamath.GetBits(ctrlByte1, MaskVersion) + h.EncodingFormat = zetamath.GetBits(ctrlByte1, MaskEncodingFormat) + + // extract operation code, reserved bits + ctrlByte2 := data[2] + h.OpCode = zetamath.GetBits(ctrlByte2, MaskOpCode) + h.Reserved = zetamath.GetBits(ctrlByte2, MaskCtrlReserved) + + // extract data flags + h.DataFlags = data[3] + + // validate header + return h.Validate() +} + +// Validate checks if the memo header is valid +func (h *Header) Validate() error { + if h.Version != 0 { + return fmt.Errorf("invalid memo version: %d", h.Version) + } + + if h.EncodingFormat >= EncodingFmtMax { + return fmt.Errorf("invalid encoding format: %d", h.EncodingFormat) + } + + if h.OpCode >= OpCodeMax { + return fmt.Errorf("invalid operation code: %d", h.OpCode) + } + + // reserved control bits must be zero + if h.Reserved != 0 { + return fmt.Errorf("reserved control bits are not zero: %d", h.Reserved) + } + + return nil +} diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go new file mode 100644 index 0000000000..0c61608c28 --- /dev/null +++ b/pkg/memo/header_test.go @@ -0,0 +1,155 @@ +package memo_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" +) + +func Test_Header_EncodeToBytes(t *testing.T) { + tests := []struct { + name string + header memo.Header + expected []byte + errMsg string + }{ + { + name: "it works", + header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + DataFlags: 0b00001111, + }, + expected: []byte{memo.Identifier, 0b00000000, 0b00100000, 0b00001111}, + }, + { + name: "header validation failed", + header: memo.Header{ + Version: 1, // invalid version + }, + errMsg: "invalid memo version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header, err := tt.header.EncodeToBytes() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + require.Nil(t, header) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, header) + }) + } +} + +func Test_Header_DecodeFromBytes(t *testing.T) { + tests := []struct { + name string + data []byte + expected memo.Header + errMsg string + }{ + { + name: "it works", + data: append(sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeCall, 0, 0), []byte{0x01, 0x02}...), + expected: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + Reserved: 0, + }, + }, + { + name: "memo is too short", + data: []byte{0x01, 0x02, 0x03, 0x04}, + errMsg: "memo is too short", + }, + { + name: "invalid memo identifier", + data: []byte{'M', 0x02, 0x03, 0x04, 0x05}, + errMsg: "invalid memo identifier", + }, + { + name: "header validation failed", + data: append( + sample.MemoHead(0, memo.EncodingFmtMax, memo.OpCodeCall, 0, 0), + []byte{0x01, 0x02}...), // invalid encoding format + errMsg: "invalid encoding format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + memo := &memo.Header{} + err := memo.DecodeFromBytes(tt.data) + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, *memo) + }) + } +} + +func Test_Header_Validate(t *testing.T) { + tests := []struct { + name string + header memo.Header + errMsg string + }{ + { + name: "valid header", + header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, + }, + }, + { + name: "invalid version", + header: memo.Header{ + Version: 1, + }, + errMsg: "invalid memo version", + }, + { + name: "invalid encoding format", + header: memo.Header{ + EncodingFormat: memo.EncodingFmtMax, + }, + errMsg: "invalid encoding format", + }, + { + name: "invalid operation code", + header: memo.Header{ + OpCode: memo.OpCodeMax, + }, + errMsg: "invalid operation code", + }, + { + name: "reserved field is not zero", + header: memo.Header{ + Reserved: 1, + }, + errMsg: "reserved control bits are not zero", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.header.Validate() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + }) + } +} diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index 1b4c20cf94..2844bf7755 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -4,76 +4,35 @@ import ( "fmt" "github.com/pkg/errors" - - zetamath "github.com/zeta-chain/node/pkg/math" -) - -const ( - // MemoIdentifier is the ASCII code of 'Z' (0x5A) - MemoIdentifier byte = 0x5A - - // MemoHeaderSize is the size of the memo header: [identifier + ctrlByte1+ ctrlByte2 + dataFlags] - MemoHeaderSize = 4 - - // MemoBasicsSize is the size of the memo basics: [identifier + ctrlByte1 + ctrlByte2] - MemoBasicsSize = 3 - - // MaskVersion is the mask for the version bits(upper 4 bits) - MaskVersion byte = 0b11110000 - - // MaskEncodingFormat is the mask for the encoding format bits(lower 4 bits) - MaskEncodingFormat byte = 0b00001111 - - // MaskOpCode is the mask for the operation code bits(upper 4 bits) - MaskOpCode byte = 0b11110000 - - // MaskCtrlReserved is the mask for reserved control bits (lower 4 bits) - MaskCtrlReserved byte = 0b00001111 -) - -// Enum for non-EVM chain inbound operation code (4 bits) -const ( - OpCodeDeposit uint8 = 0b0000 // operation 'deposit' - OpCodeDepositAndCall uint8 = 0b0001 // operation 'deposit_and_call' - OpCodeCall uint8 = 0b0010 // operation 'call' - OpCodeMax uint8 = 0b0011 // operation max value ) -// InboundMemo represents the memo structure for non-EVM chains +// InboundMemo represents the inbound memo structure for non-EVM chains type InboundMemo struct { - // Version is the memo Version - Version uint8 - - // EncodingFormat is the memo encoding format - EncodingFormat uint8 - - // OpCode is the inbound operation code - OpCode uint8 - - // Reserved is the reserved control bits - Reserved uint8 + // Header contains the memo header + Header // FieldsV0 contains the memo fields V0 - // Note: add a FieldsV1 if major update is needed in the future + // Note: add a FieldsV1 if breaking change is needed in the future FieldsV0 } // EncodeToBytes encodes a InboundMemo struct to raw bytes +// +// Note: +// - Any provided 'DataFlags' is ignored as they are calculated based on the fields set in the memo. +// - The 'RevertGasLimit' is not used for now for non-EVM chains. func (m *InboundMemo) EncodeToBytes() ([]byte, error) { - // validate memo basics - err := m.ValidateBasics() + // encode header + header, err := m.Header.EncodeToBytes() if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to encode memo header") } - // encode memo basics - basics := m.EncodeBasics() - - // encode memo fields based on version + // encode fields based on version var data []byte switch m.Version { case 0: - data, err = m.FieldsV0.Pack(m.EncodingFormat) + m.DataFlags, data, err = m.FieldsV0.Pack(m.OpCode, m.EncodingFormat) default: return nil, fmt.Errorf("invalid memo version: %d", m.Version) } @@ -81,7 +40,10 @@ func (m *InboundMemo) EncodeToBytes() ([]byte, error) { return nil, errors.Wrapf(err, "failed to pack memo fields version: %d", m.Version) } - return append(basics, data...), nil + // update data flags with the calculated value + header[3] = m.DataFlags + + return append(header, data...), nil } // DecodeFromBytes decodes a InboundMemo struct from raw bytes @@ -90,22 +52,16 @@ func (m *InboundMemo) EncodeToBytes() ([]byte, error) { func DecodeFromBytes(data []byte) (*InboundMemo, error) { memo := &InboundMemo{} - // decode memo basics - err := memo.DecodeBasics(data) - if err != nil { - return nil, err - } - - // validate memo basics - err = memo.ValidateBasics() + // decode header + err := memo.Header.DecodeFromBytes(data) if err != nil { return nil, err } - // decode memo fields based on version + // decode fields based on version switch memo.Version { case 0: - err = memo.FieldsV0.Unpack(data[MemoBasicsSize:], memo.EncodingFormat) + err = memo.FieldsV0.Unpack(memo.OpCode, memo.EncodingFormat, memo.DataFlags, data[HeaderSize:]) default: return nil, fmt.Errorf("invalid memo version: %d", memo.Version) } @@ -115,69 +71,3 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { return memo, nil } - -// Validate checks if the memo is valid -func (m *InboundMemo) ValidateBasics() error { - if m.EncodingFormat >= EncodingFmtMax { - return fmt.Errorf("invalid encoding format: %d", m.EncodingFormat) - } - - if m.OpCode >= OpCodeMax { - return fmt.Errorf("invalid operation code: %d", m.OpCode) - } - - // reserved control bits must be zero - if m.Reserved != 0 { - return fmt.Errorf("reserved control bits are not zero: %d", m.Reserved) - } - - return nil -} - -// EncodeBasics encodes theidentifier, version, encoding format and operation code -func (m *InboundMemo) EncodeBasics() []byte { - // basics: [identifier + ctrlByte1 + ctrlByte2] - basics := make([]byte, MemoBasicsSize) - - // set byte-0 as memo identifier - basics[0] = MemoIdentifier - - // set version #, encoding format - var ctrlByte1 byte - ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskVersion, m.Version) - ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskEncodingFormat, m.EncodingFormat) - basics[1] = ctrlByte1 - - // set operation code, reserved bits - var ctrlByte2 byte - ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskOpCode, m.OpCode) - ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskCtrlReserved, m.Reserved) - basics[2] = ctrlByte2 - - return basics -} - -// DecodeBasics decodes the identifier, version, encoding format and operation code -func (m *InboundMemo) DecodeBasics(data []byte) error { - // memo data must be longer than the header size - if len(data) <= MemoHeaderSize { - return errors.New("memo is too short") - } - - // byte-0 is the memo identifier - if data[0] != MemoIdentifier { - return fmt.Errorf("invalid memo identifier: %d", data[0]) - } - - // extract version #, encoding format - ctrlByte1 := data[1] - m.Version = zetamath.GetBits(ctrlByte1, MaskVersion) - m.EncodingFormat = zetamath.GetBits(ctrlByte1, MaskEncodingFormat) - - // extract operation code, reserved bits - ctrlByte2 := data[2] - m.OpCode = zetamath.GetBits(ctrlByte2, MaskOpCode) - m.Reserved = zetamath.GetBits(ctrlByte2, MaskCtrlReserved) - - return nil -} diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index dbfc011472..09726bdea2 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -9,7 +9,7 @@ import ( crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" ) -func Test_EncodeToBytes(t *testing.T) { +func Test_Memo_EncodeToBytes(t *testing.T) { // create sample fields fAddress := sample.EthAddress() fBytes := []byte("here_s_some_bytes_field") @@ -25,9 +25,11 @@ func Test_EncodeToBytes(t *testing.T) { { name: "encode memo with ABI encoding", memo: &memo.InboundMemo{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeDepositAndCall, + Header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -56,9 +58,11 @@ func Test_EncodeToBytes(t *testing.T) { { name: "encode memo with compact encoding", memo: &memo.InboundMemo{ - Version: 0, - EncodingFormat: memo.EncodingFmtCompactShort, - OpCode: memo.OpCodeDepositAndCall, + Header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, + }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -86,19 +90,38 @@ func Test_EncodeToBytes(t *testing.T) { memo.ArgRevertMessage(fBytes)), }, { - name: "failed to encode if basic validation fails", + name: "failed to encode memo header", memo: &memo.InboundMemo{ - EncodingFormat: memo.EncodingFmtMax, + Header: memo.Header{ + OpCode: memo.OpCodeMax, // invalid operation code + }, }, - errMsg: "invalid encoding format", + errMsg: "failed to encode memo header", }, { name: "failed to encode if version is invalid", memo: &memo.InboundMemo{ - Version: 1, + Header: memo.Header{ + Version: 1, + }, }, errMsg: "invalid memo version", }, + { + name: "failed to pack memo fields", + memo: &memo.InboundMemo{ + Header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeDeposit, + }, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, // payload is not allowed for deposit + }, + }, + errMsg: "failed to pack memo fields version: 0", + }, } for _, tt := range tests { @@ -106,6 +129,7 @@ func Test_EncodeToBytes(t *testing.T) { data, err := tt.memo.EncodeToBytes() if tt.errMsg != "" { require.ErrorContains(t, err, tt.errMsg) + require.Nil(t, data) return } require.NoError(t, err) @@ -114,7 +138,7 @@ func Test_EncodeToBytes(t *testing.T) { } } -func Test_DecodeFromBytes(t *testing.T) { +func Test_Memo_DecodeFromBytes(t *testing.T) { // create sample fields fAddress := sample.EthAddress() fBytes := []byte("here_s_some_bytes_field") @@ -143,9 +167,12 @@ func Test_DecodeFromBytes(t *testing.T) { memo.ArgAbortAddress(fAddress), memo.ArgRevertMessage(fBytes)), expectedMemo: memo.InboundMemo{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeDepositAndCall, + Header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + DataFlags: 0b00001111, + }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -175,9 +202,12 @@ func Test_DecodeFromBytes(t *testing.T) { memo.ArgAbortAddress(fAddress), memo.ArgRevertMessage(fBytes)), expectedMemo: memo.InboundMemo{ - Version: 0, - EncodingFormat: memo.EncodingFmtCompactLong, - OpCode: memo.OpCodeDepositAndCall, + Header: memo.Header{ + Version: 0, + EncodingFormat: memo.EncodingFmtCompactLong, + OpCode: memo.OpCodeDepositAndCall, + DataFlags: 0b00001111, + }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -233,122 +263,3 @@ func Test_DecodeFromBytes(t *testing.T) { }) } } - -func Test_ValidateBasics(t *testing.T) { - tests := []struct { - name string - memo *memo.InboundMemo - errMsg string - }{ - { - name: "valid memo", - memo: &memo.InboundMemo{ - Version: 0, - EncodingFormat: memo.EncodingFmtCompactShort, - OpCode: memo.OpCodeDepositAndCall, - }, - }, - { - name: "invalid encoding format", - memo: &memo.InboundMemo{ - EncodingFormat: memo.EncodingFmtMax, - }, - errMsg: "invalid encoding format", - }, - { - name: "invalid operation code", - memo: &memo.InboundMemo{ - OpCode: memo.OpCodeMax, - }, - errMsg: "invalid operation code", - }, - { - name: "reserved field is not zero", - memo: &memo.InboundMemo{ - Reserved: 1, - }, - errMsg: "reserved control bits are not zero", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.memo.ValidateBasics() - if tt.errMsg != "" { - require.ErrorContains(t, err, tt.errMsg) - return - } - require.NoError(t, err) - }) - } -} - -func Test_EncodeBasics(t *testing.T) { - tests := []struct { - name string - memo *memo.InboundMemo - expected []byte - errMsg string - }{ - { - name: "it works", - memo: &memo.InboundMemo{ - Version: 1, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeCall, - Reserved: 15, - }, - expected: []byte{memo.MemoIdentifier, 0b00010000, 0b00101111}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - basics := tt.memo.EncodeBasics() - require.Equal(t, tt.expected, basics) - }) - } -} - -func Test_DecodeBasics(t *testing.T) { - tests := []struct { - name string - data []byte - expected memo.InboundMemo - errMsg string - }{ - { - name: "it works", - data: append(sample.MemoHead(1, memo.EncodingFmtABI, memo.OpCodeCall, 15, 0), []byte{0x01, 0x02}...), - expected: memo.InboundMemo{ - Version: 1, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeCall, - Reserved: 15, - }, - }, - { - name: "memo is too short", - data: []byte{0x01, 0x02, 0x03, 0x04}, - errMsg: "memo is too short", - }, - { - name: "invalid memo identifier", - data: []byte{'M', 0x02, 0x03, 0x04, 0x05}, - errMsg: "invalid memo identifier", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - memo := &memo.InboundMemo{} - err := memo.DecodeBasics(tt.data) - if tt.errMsg != "" { - require.ErrorContains(t, err, tt.errMsg) - return - } - require.NoError(t, err) - require.Equal(t, tt.expected, *memo) - }) - } -} diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go index 137504f573..4bee3c24fd 100644 --- a/testutil/sample/memo.go +++ b/testutil/sample/memo.go @@ -13,8 +13,8 @@ import ( // MemoHead is a helper function to create a memo head // Note: all arguments are assume to be <= 0b1111 for simplicity. func MemoHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { - head := make([]byte, memo.MemoHeaderSize) - head[0] = memo.MemoIdentifier + head := make([]byte, memo.HeaderSize) + head[0] = memo.Identifier head[1] = version<<4 | encodingFmt head[2] = opCode<<4 | reserved head[3] = flags From a1f3da8bbe1778675c62326d196936d9d47e43f9 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 11:02:40 -0500 Subject: [PATCH 05/19] a few renaming and wrapped err message --- pkg/memo/fields_v0.go | 4 ++-- pkg/memo/memo.go | 10 +++++----- pkg/memo/memo_test.go | 7 +++---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index a9f2d1b3b3..b3b9a5729b 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -18,8 +18,8 @@ const ( ) const ( - // MaskFlagReserved is the mask for reserved data flags - MaskFlagReserved = 0b11110000 + // MaskFlagsReserved is the mask for reserved data flags + MaskFlagsReserved = 0b11110000 ) var _ Fields = (*FieldsV0)(nil) diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index 2844bf7755..7aa201690b 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -22,8 +22,8 @@ type InboundMemo struct { // - Any provided 'DataFlags' is ignored as they are calculated based on the fields set in the memo. // - The 'RevertGasLimit' is not used for now for non-EVM chains. func (m *InboundMemo) EncodeToBytes() ([]byte, error) { - // encode header - header, err := m.Header.EncodeToBytes() + // encode head + head, err := m.Header.EncodeToBytes() if err != nil { return nil, errors.Wrap(err, "failed to encode memo header") } @@ -41,9 +41,9 @@ func (m *InboundMemo) EncodeToBytes() ([]byte, error) { } // update data flags with the calculated value - header[3] = m.DataFlags + head[3] = m.DataFlags - return append(header, data...), nil + return append(head, data...), nil } // DecodeFromBytes decodes a InboundMemo struct from raw bytes @@ -55,7 +55,7 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { // decode header err := memo.Header.DecodeFromBytes(data) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to decode memo header") } // decode fields based on version diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index 09726bdea2..97629005ae 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -221,10 +221,10 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }, }, { - name: "failed to decode if basic validation fails", + name: "failed to decode memo header", head: sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeMax, 0, 0), data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), - errMsg: "invalid operation code", + errMsg: "failed to decode memo header", }, { name: "failed to decode if version is invalid", @@ -233,7 +233,6 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { errMsg: "invalid memo version", }, { - name: "failed to decode compact encoded data with ABI encoding format", head: sample.MemoHead( 0, @@ -241,7 +240,7 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { memo.OpCodeDepositAndCall, 0, 0, - ), // head says ABI encoding + ), // header says ABI encoding data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), From cb72262915074f32bbbb907bfcefcebee73bc4de Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 11:35:49 -0500 Subject: [PATCH 06/19] add extra good-to-have check for memo fields validation --- pkg/memo/fields_v0.go | 9 ++++++++- pkg/memo/fields_v0_test.go | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index b3b9a5729b..40d3977445 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -74,11 +74,18 @@ func (f *FieldsV0) Validate(opCode uint8) error { return errors.New("receiver address is empty") } - // ensure payload is not set for deposit operation + // payload is not allowed for deposit operation if opCode == OpCodeDeposit && len(f.Payload) > 0 { return errors.New("payload is not allowed for deposit operation") } + // revert message is not allowed when CallOnRevert is false + // 1. it's a good-to-have check to make the fields semantically correct. + // 2. unpacking won't hit this error as the codec will catch it earlier. + if !f.RevertOptions.CallOnRevert && len(f.RevertOptions.RevertMessage) > 0 { + return errors.New("revert message is not allowed when CallOnRevert is false") + } + return nil } diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index 6d3b8c6706..84f9ce4939 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -296,6 +296,19 @@ func Test_V0_Validate(t *testing.T) { }, errMsg: "payload is not allowed for deposit operation", }, + { + name: "revert message is not allowed when CallOnRevert is false", + opCode: memo.OpCodeDeposit, + fields: memo.FieldsV0{ + Receiver: fAddress, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: false, // CallOnRevert is false + RevertMessage: []byte("revert message"), // revert message is mistakenly set + }, + }, + errMsg: "revert message is not allowed when CallOnRevert is false", + }, } for _, tc := range tests { From ed34ac4407425b5f13322e21640660a52069a582 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 11:51:20 -0500 Subject: [PATCH 07/19] add changelog entry --- changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/changelog.md b/changelog.md index e25f46e129..b57f84ca7f 100644 --- a/changelog.md +++ b/changelog.md @@ -18,6 +18,7 @@ * [2919](https://github.com/zeta-chain/node/pull/2919) - add inbound sender to revert context * [2957](https://github.com/zeta-chain/node/pull/2957) - enable Bitcoin inscription support on testnet * [2896](https://github.com/zeta-chain/node/pull/2896) - add TON inbound observation +* [2987](https://github.com/zeta-chain/node/pull/2987) - add non-EVM standard inbound memo package ### Refactor From 416b8d4f189bd8393153acadefe9e5c5c85bc836 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 12:20:28 -0500 Subject: [PATCH 08/19] fix nosec error --- pkg/memo/codec_compact.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/memo/codec_compact.go b/pkg/memo/codec_compact.go index def5213a01..f295c0e0ab 100644 --- a/pkg/memo/codec_compact.go +++ b/pkg/memo/codec_compact.go @@ -122,6 +122,7 @@ func (c *CodecCompact) packLength(length int) ([]byte, error) { if length > math.MaxUint16 { return nil, fmt.Errorf("data length %d exceeds %d bytes", length, math.MaxUint16) } + // #nosec G115 range checked binary.LittleEndian.PutUint16(data, uint16(length)) } return data, nil From c48d999ddc0316f9f5355ee6965d04d97d5a8bf3 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 12:33:02 -0500 Subject: [PATCH 09/19] add two more unit tests for missed lines --- pkg/memo/codec_compact_test.go | 9 +++++++++ pkg/memo/fields_v0_test.go | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go index 90f9127c08..046acfea3e 100644 --- a/pkg/memo/codec_compact_test.go +++ b/pkg/memo/codec_compact_test.go @@ -234,6 +234,15 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { memo.ArgRevertAddress(""), }, }, + { + name: "failed to unpack address if data length < 20 bytes", + encodingFmt: memo.EncodingFmtCompactShort, + data: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + }, + errMsg: "expected address, got 5 bytes", + }, { name: "failed to unpack string if data length < 1 byte", encodingFmt: memo.EncodingFmtCompactShort, diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index 84f9ce4939..afcdce9842 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -212,6 +212,14 @@ func Test_V0_Unpack(t *testing.T) { Payload: []byte{}, }, }, + { + name: "unable to get codec on invalid encoding format", + opCode: memo.OpCodeDepositAndCall, + encodingFormat: 0x0F, + flags: 0b00000001, + data: []byte{}, + errMsg: "unable to get codec", + }, { name: "failed to unpack ABI encoded data with compact encoding format", opCode: memo.OpCodeDepositAndCall, From 9f92e23aefc27cf334a897247885885831908c33 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 12:45:12 -0500 Subject: [PATCH 10/19] remove redundant dependency github.com/test-go/testify v1.1.4 --- go.mod | 1 - 1 file changed, 1 deletion(-) diff --git a/go.mod b/go.mod index ba1bea9b91..7625fc2047 100644 --- a/go.mod +++ b/go.mod @@ -336,7 +336,6 @@ require ( require ( github.com/bnb-chain/tss-lib v1.5.0 github.com/showa-93/go-mask v0.6.2 - github.com/test-go/testify v1.1.4 github.com/tonkeeper/tongo v1.9.3 ) From 6021f984a1d40cfb637228aec78b163cbfc71d72 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 13:22:29 -0500 Subject: [PATCH 11/19] enhance codec error message --- pkg/memo/codec.go | 12 +++++------- pkg/memo/codec_test.go | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pkg/memo/codec.go b/pkg/memo/codec.go index 94539108a3..06f1d71673 100644 --- a/pkg/memo/codec.go +++ b/pkg/memo/codec.go @@ -2,23 +2,21 @@ package memo import ( "fmt" - - "github.com/pkg/errors" ) // Enum for non-EVM chain memo encoding format (2 bits) const ( // EncodingFmtABI represents ABI encoding format - EncodingFmtABI uint8 = 0b00 + EncodingFmtABI uint8 = 0b0000 // EncodingFmtCompactShort represents 'compact short' encoding format - EncodingFmtCompactShort uint8 = 0b01 + EncodingFmtCompactShort uint8 = 0b0001 // EncodingFmtCompactLong represents 'compact long' encoding format - EncodingFmtCompactLong uint8 = 0b10 + EncodingFmtCompactLong uint8 = 0b0010 // EncodingFmtMax is the max value of encoding format - EncodingFmtMax uint8 = 0b11 + EncodingFmtMax uint8 = 0b0011 ) // Enum for length of bytes used to encode compact data @@ -59,6 +57,6 @@ func GetCodec(encodingFormat uint8) (Codec, error) { case EncodingFmtCompactShort, EncodingFmtCompactLong: return NewCodecCompact(encodingFormat) default: - return nil, errors.New("unsupported encoding format") + return nil, fmt.Errorf("invalid encoding format %d", encodingFormat) } } diff --git a/pkg/memo/codec_test.go b/pkg/memo/codec_test.go index 27eb73d977..182414eeeb 100644 --- a/pkg/memo/codec_test.go +++ b/pkg/memo/codec_test.go @@ -71,8 +71,8 @@ func Test_GetCodec(t *testing.T) { }, { name: "should fail to get codec", - encodingFmt: 0b11, - errMsg: "unsupported", + encodingFmt: 0b0011, + errMsg: "invalid encoding format", }, } From fc133773ecee99bc54819dbe6a97553105176b35 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 13:47:54 -0500 Subject: [PATCH 12/19] a few renaming and move test constant to testutil pkg --- pkg/memo/codec.go | 4 ++-- pkg/memo/codec_abi.go | 8 +++++--- pkg/memo/header.go | 6 +++--- pkg/memo/header_test.go | 6 +++--- pkg/memo/memo_test.go | 4 ++-- testutil/sample/memo.go | 17 +++++++++++------ 6 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pkg/memo/codec.go b/pkg/memo/codec.go index 06f1d71673..d1f592f4b3 100644 --- a/pkg/memo/codec.go +++ b/pkg/memo/codec.go @@ -15,8 +15,8 @@ const ( // EncodingFmtCompactLong represents 'compact long' encoding format EncodingFmtCompactLong uint8 = 0b0010 - // EncodingFmtMax is the max value of encoding format - EncodingFmtMax uint8 = 0b0011 + // EncodingFmtInvalid represents invalid encoding format + EncodingFmtInvalid uint8 = 0b0011 ) // Enum for length of bytes used to encode compact data diff --git a/pkg/memo/codec_abi.go b/pkg/memo/codec_abi.go index 9517722006..708e1a6398 100644 --- a/pkg/memo/codec_abi.go +++ b/pkg/memo/codec_abi.go @@ -9,9 +9,6 @@ import ( ) const ( - // ABIAlignment is the number of bytes used to align the ABI encoded data - ABIAlignment = 32 - // selectorLength is the length of the selector in bytes selectorLength = 4 @@ -64,6 +61,11 @@ func (c *CodecABI) PackArguments() ([]byte, error) { return nil, errors.Wrap(err, "failed to pack ABI arguments") } + // this never happens + if len(data) < selectorLength { + return nil, errors.New("packed data less than selector length") + } + return data[selectorLength:], nil } diff --git a/pkg/memo/header.go b/pkg/memo/header.go index b2638f38b8..87aa683b88 100644 --- a/pkg/memo/header.go +++ b/pkg/memo/header.go @@ -33,7 +33,7 @@ const ( OpCodeDeposit uint8 = 0b0000 // operation 'deposit' OpCodeDepositAndCall uint8 = 0b0001 // operation 'deposit_and_call' OpCodeCall uint8 = 0b0010 // operation 'call' - OpCodeMax uint8 = 0b0011 // operation max value + OpCodeInvalid uint8 = 0b0011 // invalid operation code ) // Header represent the memo header @@ -120,11 +120,11 @@ func (h *Header) Validate() error { return fmt.Errorf("invalid memo version: %d", h.Version) } - if h.EncodingFormat >= EncodingFmtMax { + if h.EncodingFormat >= EncodingFmtInvalid { return fmt.Errorf("invalid encoding format: %d", h.EncodingFormat) } - if h.OpCode >= OpCodeMax { + if h.OpCode >= OpCodeInvalid { return fmt.Errorf("invalid operation code: %d", h.OpCode) } diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go index 0c61608c28..c22c2c765a 100644 --- a/pkg/memo/header_test.go +++ b/pkg/memo/header_test.go @@ -78,7 +78,7 @@ func Test_Header_DecodeFromBytes(t *testing.T) { { name: "header validation failed", data: append( - sample.MemoHead(0, memo.EncodingFmtMax, memo.OpCodeCall, 0, 0), + sample.MemoHead(0, memo.EncodingFmtInvalid, memo.OpCodeCall, 0, 0), []byte{0x01, 0x02}...), // invalid encoding format errMsg: "invalid encoding format", }, @@ -122,14 +122,14 @@ func Test_Header_Validate(t *testing.T) { { name: "invalid encoding format", header: memo.Header{ - EncodingFormat: memo.EncodingFmtMax, + EncodingFormat: memo.EncodingFmtInvalid, }, errMsg: "invalid encoding format", }, { name: "invalid operation code", header: memo.Header{ - OpCode: memo.OpCodeMax, + OpCode: memo.OpCodeInvalid, }, errMsg: "invalid operation code", }, diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index 97629005ae..9917ba4e59 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -93,7 +93,7 @@ func Test_Memo_EncodeToBytes(t *testing.T) { name: "failed to encode memo header", memo: &memo.InboundMemo{ Header: memo.Header{ - OpCode: memo.OpCodeMax, // invalid operation code + OpCode: memo.OpCodeInvalid, // invalid operation code }, }, errMsg: "failed to encode memo header", @@ -222,7 +222,7 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }, { name: "failed to decode memo header", - head: sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeMax, 0, 0), + head: sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeInvalid, 0, 0), data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), errMsg: "failed to decode memo header", }, diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go index 4bee3c24fd..ae526e3b56 100644 --- a/testutil/sample/memo.go +++ b/testutil/sample/memo.go @@ -10,6 +10,11 @@ import ( "github.com/zeta-chain/node/pkg/memo" ) +const ( + // abiAlignment is the number of bytes used to align the ABI encoded data + abiAlignment = 32 +) + // MemoHead is a helper function to create a memo head // Note: all arguments are assume to be <= 0b1111 for simplicity. func MemoHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { @@ -27,7 +32,7 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { packedData := make([]byte, 0) // data offset for 1st dynamic-length field - offset := memo.ABIAlignment * len(args) + offset := abiAlignment * len(args) // 1. pack 32-byte offset for each dynamic-length field (bytes, string) // 2. pack actual data for each fixed-length field (address) @@ -42,9 +47,9 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { argLen := len(arg.Arg.([]byte)) if argLen > 0 { - offset += memo.ABIAlignment * 2 // [length + data] + offset += abiAlignment * 2 // [length + data] } else { - offset += memo.ABIAlignment // only [length] + offset += abiAlignment // only [length] } case memo.ArgTypeString: @@ -56,9 +61,9 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { argLen := len([]byte(arg.Arg.(string))) if argLen > 0 { - offset += memo.ABIAlignment * 2 // [length + data] + offset += abiAlignment * 2 // [length + data] } else { - offset += memo.ABIAlignment // only [length] + offset += abiAlignment // only [length] } case memo.ArgTypeAddress: // left-pad for address @@ -127,7 +132,7 @@ func abiPad32(t *testing.T, data []byte, left bool) []byte { return []byte{} } - require.LessOrEqual(t, len(data), memo.ABIAlignment) + require.LessOrEqual(t, len(data), abiAlignment) padded := make([]byte, 32) if left { From 5b59e19a61af23fff0af5c223bb81554bf3c83e5 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 14:35:17 -0500 Subject: [PATCH 13/19] corrected typo and improved unit tests --- pkg/memo/codec_abi_test.go | 7 ++++++- pkg/memo/codec_compact.go | 4 ++-- pkg/memo/codec_compact_test.go | 11 +++++++++-- pkg/memo/fields_v0_test.go | 13 +++++++++---- pkg/memo/header.go | 2 +- pkg/memo/header_test.go | 2 +- pkg/memo/memo_test.go | 16 ++++++++-------- testutil/sample/memo.go | 6 +++--- 8 files changed, 39 insertions(+), 22 deletions(-) diff --git a/pkg/memo/codec_abi_test.go b/pkg/memo/codec_abi_test.go index f61f9ac774..b82608d929 100644 --- a/pkg/memo/codec_abi_test.go +++ b/pkg/memo/codec_abi_test.go @@ -33,7 +33,7 @@ func ensureArgEquality(t *testing.T, expected, actual interface{}) { case string: require.Equal(t, v, *actual.(*string)) default: - t.Fatalf("unexpected argument type: %T", v) + require.FailNow(t, "unexpected argument type", "Type: %T", v) } } @@ -48,6 +48,11 @@ func Test_CodecABI_AddArguments(t *testing.T) { address := sample.EthAddress() codec.AddArguments(memo.ArgReceiver(&address)) + + // attempt to pack the arguments, result should not be nil + packedData, err := codec.PackArguments() + require.NoError(t, err) + require.True(t, len(packedData) > 0) } func Test_CodecABI_PackArgument(t *testing.T) { diff --git a/pkg/memo/codec_compact.go b/pkg/memo/codec_compact.go index f295c0e0ab..62963e50d8 100644 --- a/pkg/memo/codec_compact.go +++ b/pkg/memo/codec_compact.go @@ -52,11 +52,11 @@ func (c *CodecCompact) PackArguments() ([]byte, error) { } data = append(data, dataBytes...) case ArgTypeAddress: - dateAddress, err := c.packAddress(arg.Arg) + dataAddress, err := c.packAddress(arg.Arg) if err != nil { return nil, errors.Wrapf(err, "failed to pack address argument: %s", arg.Name) } - data = append(data, dateAddress...) + data = append(data, dataAddress...) case ArgTypeString: dataString, err := c.packString(arg.Arg) if err != nil { diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go index 046acfea3e..72600bf54f 100644 --- a/pkg/memo/codec_compact_test.go +++ b/pkg/memo/codec_compact_test.go @@ -4,6 +4,7 @@ import ( "bytes" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/testutil/sample" @@ -41,11 +42,17 @@ func Test_NewCodecCompact(t *testing.T) { } func Test_CodecCompact_AddArguments(t *testing.T) { - codec := memo.NewCodecABI() + codec, err := memo.NewCodecCompact(memo.EncodingFmtCompactLong) + require.NoError(t, err) require.NotNil(t, codec) address := sample.EthAddress() - codec.AddArguments(memo.ArgReceiver(&address)) + codec.AddArguments(memo.ArgReceiver(address)) + + // attempt to pack the arguments, result should not be nil + packedData, err := codec.PackArguments() + require.NoError(t, err) + require.True(t, len(packedData) == common.AddressLength) } func Test_CodecCompact_PackArguments(t *testing.T) { diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index afcdce9842..38b6c08e55 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -11,6 +11,11 @@ import ( crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" ) +const ( + // flagsAllFieldsSet sets all fields: [payload, revert address, abort address, CallOnRevert] + flagsAllFieldsSet = 0b00001111 +) + func Test_V0_Pack(t *testing.T) { // create sample fields fAddress := sample.EthAddress() @@ -40,7 +45,7 @@ func Test_V0_Pack(t *testing.T) { RevertMessage: fBytes, }, }, - expectedFlags: 0b00001111, // all fields are set + expectedFlags: flagsAllFieldsSet, // all fields are set expectedData: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -62,7 +67,7 @@ func Test_V0_Pack(t *testing.T) { RevertMessage: fBytes, }, }, - expectedFlags: 0b00001111, // all fields are set + expectedFlags: flagsAllFieldsSet, // all fields are set expectedData: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), @@ -144,7 +149,7 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack all fields with ABI encoding", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, - flags: 0b00001111, // all fields are set + flags: flagsAllFieldsSet, // all fields are set data: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -166,7 +171,7 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack all fields with compact encoding", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, - flags: 0b00001111, // all fields are set + flags: flagsAllFieldsSet, // all fields are set data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), diff --git a/pkg/memo/header.go b/pkg/memo/header.go index 87aa683b88..fc614eacab 100644 --- a/pkg/memo/header.go +++ b/pkg/memo/header.go @@ -88,7 +88,7 @@ func (h *Header) EncodeToBytes() ([]byte, error) { // DecodeFromBytes decodes the memo header from the given data func (h *Header) DecodeFromBytes(data []byte) error { // memo data must be longer than the header size - if len(data) <= HeaderSize { + if len(data) < HeaderSize { return errors.New("memo is too short") } diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go index c22c2c765a..d4db80e297 100644 --- a/pkg/memo/header_test.go +++ b/pkg/memo/header_test.go @@ -67,7 +67,7 @@ func Test_Header_DecodeFromBytes(t *testing.T) { }, { name: "memo is too short", - data: []byte{0x01, 0x02, 0x03, 0x04}, + data: []byte{0x01, 0x02, 0x03}, errMsg: "memo is too short", }, { diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index 9917ba4e59..e439880f7c 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -46,8 +46,8 @@ func Test_Memo_EncodeToBytes(t *testing.T) { memo.EncodingFmtABI, memo.OpCodeDepositAndCall, 0, - 0b00001111, - ), // all fields are set + flagsAllFieldsSet, // all fields are set + ), expectedData: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -79,8 +79,8 @@ func Test_Memo_EncodeToBytes(t *testing.T) { memo.EncodingFmtCompactShort, memo.OpCodeDepositAndCall, 0, - 0b00001111, - ), // all fields are set + flagsAllFieldsSet, // all fields are set + ), expectedData: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), @@ -158,8 +158,8 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { memo.EncodingFmtABI, memo.OpCodeDepositAndCall, 0, - 0b00001111, - ), // all fields are set + flagsAllFieldsSet, // all fields are set + ), data: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -192,8 +192,8 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { memo.EncodingFmtCompactLong, memo.OpCodeDepositAndCall, 0, - 0b00001111, - ), // all fields are set + flagsAllFieldsSet, // all fields are set + ), data: sample.CompactPack( memo.EncodingFmtCompactLong, memo.ArgReceiver(fAddress), diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go index ae526e3b56..eb6a2351fa 100644 --- a/testutil/sample/memo.go +++ b/testutil/sample/memo.go @@ -16,7 +16,7 @@ const ( ) // MemoHead is a helper function to create a memo head -// Note: all arguments are assume to be <= 0b1111 for simplicity. +// Note: all arguments are assumed to be <= 0b1111 for simplicity. func MemoHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { head := make([]byte, memo.HeaderSize) head[0] = memo.Identifier @@ -26,7 +26,7 @@ func MemoHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { return head } -// ABIPack is a helper function to simulates the abi.Pack function. +// ABIPack is a helper function that simulates the abi.Pack function. // Note: all arguments are assumed to be <= 32 bytes for simplicity. func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { packedData := make([]byte, 0) @@ -145,7 +145,7 @@ func abiPad32(t *testing.T, data []byte, left bool) []byte { return padded } -// apiPackDynamicData is a helper function to pack dynamic-length data +// abiPackDynamicData is a helper function to pack dynamic-length data func abiPackDynamicData(t *testing.T, args ...memo.CodecArg) []byte { packedData := make([]byte, 0) From b3fabbda86643d5e558664212f87d2bd3b99dbb1 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 10 Oct 2024 16:36:24 -0500 Subject: [PATCH 14/19] fix build --- pkg/memo/codec_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/memo/codec_test.go b/pkg/memo/codec_test.go index 182414eeeb..6ac3a3cd0d 100644 --- a/pkg/memo/codec_test.go +++ b/pkg/memo/codec_test.go @@ -3,7 +3,7 @@ package memo_test import ( "testing" - "github.com/test-go/testify/require" + "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" ) From dd80da87f1fe7ebdb52baf15ca9c287c566f1028 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Fri, 11 Oct 2024 17:08:09 -0500 Subject: [PATCH 15/19] make receiver address optional --- pkg/memo/codec_compact.go | 6 +- pkg/memo/fields.go | 7 +- pkg/memo/fields_v0.go | 97 ++++++++++++++++-------- pkg/memo/fields_v0_test.go | 148 ++++++++++++++++++++++++------------- pkg/memo/header_test.go | 4 +- pkg/memo/memo.go | 11 +-- pkg/memo/memo_test.go | 4 +- 7 files changed, 183 insertions(+), 94 deletions(-) diff --git a/pkg/memo/codec_compact.go b/pkg/memo/codec_compact.go index 62963e50d8..379e627a8f 100644 --- a/pkg/memo/codec_compact.go +++ b/pkg/memo/codec_compact.go @@ -233,8 +233,10 @@ func (c *CodecCompact) unpackBytes(data []byte, output interface{}) (int, error) } // make a copy of the data - *pSlice = make([]byte, dataLen) - copy(*pSlice, data[c.lenBytes:c.lenBytes+dataLen]) + if dataLen > 0 { + *pSlice = make([]byte, dataLen) + copy(*pSlice, data[c.lenBytes:c.lenBytes+dataLen]) + } return c.lenBytes + dataLen, nil } diff --git a/pkg/memo/fields.go b/pkg/memo/fields.go index 06def4a6b5..9baeab19af 100644 --- a/pkg/memo/fields.go +++ b/pkg/memo/fields.go @@ -3,11 +3,14 @@ package memo // Fields is the interface for memo fields type Fields interface { // Pack encodes the memo fields - Pack(opCode, encodingFormat uint8) (byte, []byte, error) + Pack(opCode, encodingFormat, dataFlags uint8) ([]byte, error) // Unpack decodes the memo fields Unpack(opCode, encodingFormat, dataFlags uint8, data []byte) error // Validate checks if the fields are valid - Validate(opCode uint8) error + Validate(opCode, dataFlags uint8) error + + // DataFlags build the data flags for the fields + DataFlags() uint8 } diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index 40d3977445..e6b6a5f986 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -11,15 +11,16 @@ import ( // Enum of the bit position of each memo fields const ( - bitPosPayload uint8 = 0 // payload - bitPosRevertAddress uint8 = 1 // revertAddress - bitPosAbortAddress uint8 = 2 // abortAddress - bitPosRevertMessage uint8 = 3 // revertMessage + bitPosReceiver uint8 = 0 // receiver + bitPosPayload uint8 = 1 // payload + bitPosRevertAddress uint8 = 2 // revertAddress + bitPosAbortAddress uint8 = 3 // abortAddress + bitPosRevertMessage uint8 = 4 // revertMessage ) const ( // MaskFlagsReserved is the mask for reserved data flags - MaskFlagsReserved = 0b11110000 + MaskFlagsReserved = 0b11100000 ) var _ Fields = (*FieldsV0)(nil) @@ -37,23 +38,23 @@ type FieldsV0 struct { } // Pack encodes the memo fields -func (f *FieldsV0) Pack(opCode uint8, encodingFormat uint8) (byte, []byte, error) { +func (f *FieldsV0) Pack(opCode uint8, encodingFormat uint8, dataFlags uint8) ([]byte, error) { // validate fields - err := f.Validate(opCode) + err := f.Validate(opCode, dataFlags) if err != nil { - return 0, nil, err + return nil, err } codec, err := GetCodec(encodingFormat) if err != nil { - return 0, nil, errors.Wrap(err, "unable to get codec") + return nil, errors.Wrap(err, "unable to get codec") } - return f.packFields(codec) + return f.packFields(codec, dataFlags) } // Unpack decodes the memo fields -func (f *FieldsV0) Unpack(opCode uint8, encodingFormat uint8, dataFlags byte, data []byte) error { +func (f *FieldsV0) Unpack(opCode uint8, encodingFormat uint8, dataFlags uint8, data []byte) error { codec, err := GetCodec(encodingFormat) if err != nil { return errors.Wrap(err, "unable to get codec") @@ -64,13 +65,13 @@ func (f *FieldsV0) Unpack(opCode uint8, encodingFormat uint8, dataFlags byte, da return err } - return f.Validate(opCode) + return f.Validate(opCode, dataFlags) } // Validate checks if the fields are valid -func (f *FieldsV0) Validate(opCode uint8) error { - // check if receiver is empty - if crypto.IsEmptyAddress(f.Receiver) { +func (f *FieldsV0) Validate(opCode uint8, dataFlags uint8) error { + // receiver address must be a valid address + if zetamath.IsBitSet(dataFlags, bitPosReceiver) && crypto.IsEmptyAddress(f.Receiver) { return errors.New("receiver address is empty") } @@ -79,6 +80,11 @@ func (f *FieldsV0) Validate(opCode uint8) error { return errors.New("payload is not allowed for deposit operation") } + // abort address must be a valid address + if zetamath.IsBitSet(dataFlags, bitPosAbortAddress) && !common.IsHexAddress(f.RevertOptions.AbortAddress) { + return errors.New("invalid abort address") + } + // revert message is not allowed when CallOnRevert is false // 1. it's a good-to-have check to make the fields semantically correct. // 2. unpacking won't hit this error as the codec will catch it earlier. @@ -89,52 +95,81 @@ func (f *FieldsV0) Validate(opCode uint8) error { return nil } -// packFieldsV0 packs the memo fields for version 0 -func (f *FieldsV0) packFields(codec Codec) (byte, []byte, error) { - // create data flags byte - var dataFlags byte +// DataFlags build the data flags from actual fields +func (f *FieldsV0) DataFlags() uint8 { + var dataFlags uint8 - // add 'receiver' as the first argument - codec.AddArguments(ArgReceiver(f.Receiver)) + // set 'receiver' flag if provided + if !crypto.IsEmptyAddress(f.Receiver) { + zetamath.SetBit(&dataFlags, bitPosReceiver) + } - // add 'payload' argument optionally + // set 'payload' flag if provided if len(f.Payload) > 0 { zetamath.SetBit(&dataFlags, bitPosPayload) - codec.AddArguments(ArgPayload(f.Payload)) } - // add 'revertAddress' argument optionally + // set 'revertAddress' flag if provided if f.RevertOptions.RevertAddress != "" { zetamath.SetBit(&dataFlags, bitPosRevertAddress) + } + + // set 'abortAddress' flag if provided + if f.RevertOptions.AbortAddress != "" { + zetamath.SetBit(&dataFlags, bitPosAbortAddress) + } + + // set 'revertMessage' flag if provided + if f.RevertOptions.CallOnRevert { + zetamath.SetBit(&dataFlags, bitPosRevertMessage) + } + + return dataFlags +} + +// packFieldsV0 packs the memo fields for version 0 +func (f *FieldsV0) packFields(codec Codec, dataFlags uint8) ([]byte, error) { + // add 'receiver' argument optionally + if zetamath.IsBitSet(dataFlags, bitPosReceiver) { + codec.AddArguments(ArgReceiver(f.Receiver)) + } + + // add 'payload' argument optionally + if zetamath.IsBitSet(dataFlags, bitPosPayload) { + codec.AddArguments(ArgPayload(f.Payload)) + } + + // add 'revertAddress' argument optionally + if zetamath.IsBitSet(dataFlags, bitPosRevertAddress) { codec.AddArguments(ArgRevertAddress(f.RevertOptions.RevertAddress)) } // add 'abortAddress' argument optionally abortAddress := common.HexToAddress(f.RevertOptions.AbortAddress) - if !crypto.IsEmptyAddress(abortAddress) { - zetamath.SetBit(&dataFlags, bitPosAbortAddress) + if zetamath.IsBitSet(dataFlags, bitPosAbortAddress) { codec.AddArguments(ArgAbortAddress(abortAddress)) } // add 'revertMessage' argument optionally if f.RevertOptions.CallOnRevert { - zetamath.SetBit(&dataFlags, bitPosRevertMessage) codec.AddArguments(ArgRevertMessage(f.RevertOptions.RevertMessage)) } // pack the codec arguments into data data, err := codec.PackArguments() if err != nil { // never happens - return 0, nil, errors.Wrap(err, "failed to pack arguments") + return nil, errors.Wrap(err, "failed to pack arguments") } - return dataFlags, data, nil + return data, nil } // unpackFields unpacks the memo fields for version 0 func (f *FieldsV0) unpackFields(codec Codec, dataFlags byte, data []byte) error { - // add 'receiver' as the first argument - codec.AddArguments(ArgReceiver(&f.Receiver)) + // add 'receiver' argument optionally + if zetamath.IsBitSet(dataFlags, bitPosReceiver) { + codec.AddArguments(ArgReceiver(&f.Receiver)) + } // add 'payload' argument optionally if zetamath.IsBitSet(dataFlags, bitPosPayload) { diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index 38b6c08e55..124b7af70f 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -12,8 +12,8 @@ import ( ) const ( - // flagsAllFieldsSet sets all fields: [payload, revert address, abort address, CallOnRevert] - flagsAllFieldsSet = 0b00001111 + // flagsAllFieldsSet sets all fields: [receiver, payload, revert address, abort address, CallOnRevert] + flagsAllFieldsSet = 0b00011111 ) func Test_V0_Pack(t *testing.T) { @@ -26,6 +26,7 @@ func Test_V0_Pack(t *testing.T) { name string opCode uint8 encodingFormat uint8 + dataFlags uint8 fields memo.FieldsV0 expectedFlags byte expectedData []byte @@ -35,6 +36,7 @@ func Test_V0_Pack(t *testing.T) { name: "pack all fields with ABI encoding", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, + dataFlags: flagsAllFieldsSet, // all fields are set fields: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -45,7 +47,6 @@ func Test_V0_Pack(t *testing.T) { RevertMessage: fBytes, }, }, - expectedFlags: flagsAllFieldsSet, // all fields are set expectedData: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -57,6 +58,7 @@ func Test_V0_Pack(t *testing.T) { name: "pack all fields with compact encoding", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, + dataFlags: flagsAllFieldsSet, // all fields are set fields: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -67,7 +69,6 @@ func Test_V0_Pack(t *testing.T) { RevertMessage: fBytes, }, }, - expectedFlags: flagsAllFieldsSet, // all fields are set expectedData: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), @@ -76,23 +77,11 @@ func Test_V0_Pack(t *testing.T) { memo.ArgAbortAddress(fAddress), memo.ArgRevertMessage(fBytes)), }, - { - name: "should not pack invalid abort address", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtABI, - fields: memo.FieldsV0{ - Receiver: fAddress, - RevertOptions: crosschaintypes.RevertOptions{ - AbortAddress: "invalid_address", - }, - }, - expectedFlags: 0b00000000, // no flag is set - expectedData: sample.ABIPack(t, memo.ArgReceiver(fAddress)), - }, { name: "fields validation failed due to empty receiver address", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, + dataFlags: 0b00000001, // receiver flag is set fields: memo.FieldsV0{ Receiver: common.Address{}, }, @@ -112,19 +101,17 @@ func Test_V0_Pack(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // pack the fields - flags, data, err := tc.fields.Pack(tc.opCode, tc.encodingFormat) + data, err := tc.fields.Pack(tc.opCode, tc.encodingFormat, tc.dataFlags) // validate the error message if tc.errMsg != "" { require.ErrorContains(t, err, tc.errMsg) - require.Zero(t, flags) require.Nil(t, data) return } // compare the fields require.NoError(t, err) - require.Equal(t, tc.expectedFlags, flags) require.True(t, bytes.Equal(tc.expectedData, data)) }) } @@ -140,7 +127,7 @@ func Test_V0_Unpack(t *testing.T) { name string opCode uint8 encodingFormat uint8 - flags byte + dataFlags byte data []byte expected memo.FieldsV0 errMsg string @@ -149,7 +136,7 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack all fields with ABI encoding", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, - flags: flagsAllFieldsSet, // all fields are set + dataFlags: flagsAllFieldsSet, // all fields are set data: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -171,7 +158,7 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack all fields with compact encoding", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, - flags: flagsAllFieldsSet, // all fields are set + dataFlags: flagsAllFieldsSet, // all fields are set data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), @@ -194,34 +181,26 @@ func Test_V0_Unpack(t *testing.T) { name: "unpack empty ABI encoded payload if flag is set", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, - flags: 0b00000001, // payload flag is set + dataFlags: 0b00000010, // payload flags are set data: sample.ABIPack(t, - memo.ArgReceiver(fAddress), memo.ArgPayload([]byte{})), // empty payload - expected: memo.FieldsV0{ - Receiver: fAddress, - Payload: []byte{}, - }, + expected: memo.FieldsV0{}, }, { - name: "unpack empty compact encoded payload if flag is not set", + name: "unpack empty compact encoded payload if flag is set", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, - flags: 0b00000001, // payload flag is set + dataFlags: 0b00000010, // payload flag is set data: sample.CompactPack( memo.EncodingFmtCompactShort, - memo.ArgReceiver(fAddress), memo.ArgPayload([]byte{})), // empty payload - expected: memo.FieldsV0{ - Receiver: fAddress, - Payload: []byte{}, - }, + expected: memo.FieldsV0{}, }, { name: "unable to get codec on invalid encoding format", opCode: memo.OpCodeDepositAndCall, encodingFormat: 0x0F, - flags: 0b00000001, + dataFlags: 0b00000001, data: []byte{}, errMsg: "unable to get codec", }, @@ -229,7 +208,7 @@ func Test_V0_Unpack(t *testing.T) { name: "failed to unpack ABI encoded data with compact encoding format", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtCompactShort, - flags: 0b00000001, + dataFlags: 0b00000011, // receiver and payload flags are set data: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes)), @@ -239,7 +218,7 @@ func Test_V0_Unpack(t *testing.T) { name: "fields validation failed due to empty receiver address", opCode: memo.OpCodeDepositAndCall, encodingFormat: memo.EncodingFmtABI, - flags: 0b00000001, + dataFlags: 0b00000011, // receiver and payload flags are set data: sample.ABIPack(t, memo.ArgReceiver(common.Address{}), memo.ArgPayload(fBytes)), @@ -251,7 +230,7 @@ func Test_V0_Unpack(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // unpack the fields fields := memo.FieldsV0{} - err := fields.Unpack(tc.opCode, tc.encodingFormat, tc.flags, tc.data) + err := fields.Unpack(tc.opCode, tc.encodingFormat, tc.dataFlags, tc.data) // validate the error message if tc.errMsg != "" { @@ -273,14 +252,16 @@ func Test_V0_Validate(t *testing.T) { fString := "this_is_a_string_field" tests := []struct { - name string - opCode uint8 - fields memo.FieldsV0 - errMsg string + name string + opCode uint8 + dataFlags uint8 + fields memo.FieldsV0 + errMsg string }{ { - name: "valid fields", - opCode: memo.OpCodeDepositAndCall, + name: "valid fields", + opCode: memo.OpCodeDepositAndCall, + dataFlags: flagsAllFieldsSet, // all fields are set fields: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -293,10 +274,11 @@ func Test_V0_Validate(t *testing.T) { }, }, { - name: "invalid receiver address", - opCode: memo.OpCodeCall, + name: "invalid receiver address", + opCode: memo.OpCodeCall, + dataFlags: 0b00000001, // receiver flag is set fields: memo.FieldsV0{ - Receiver: common.Address{}, // empty receiver address + Receiver: common.Address{}, // provide empty receiver address }, errMsg: "receiver address is empty", }, @@ -309,6 +291,17 @@ func Test_V0_Validate(t *testing.T) { }, errMsg: "payload is not allowed for deposit operation", }, + { + name: "abort address is invalid", + opCode: memo.OpCodeDeposit, + dataFlags: 0b00001000, // abort address flag is set + fields: memo.FieldsV0{ + RevertOptions: crosschaintypes.RevertOptions{ + AbortAddress: "invalid abort address", + }, + }, + errMsg: "invalid abort address", + }, { name: "revert message is not allowed when CallOnRevert is false", opCode: memo.OpCodeDeposit, @@ -327,7 +320,7 @@ func Test_V0_Validate(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // validate the fields - err := tc.fields.Validate(tc.opCode) + err := tc.fields.Validate(tc.opCode, tc.dataFlags) // validate the error message if tc.errMsg != "" { @@ -338,3 +331,58 @@ func Test_V0_Validate(t *testing.T) { }) } } + +func Test_V0_DataFlags(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + fields memo.FieldsV0 + expectedFlags uint8 + }{ + { + name: "all fields set", + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + expectedFlags: flagsAllFieldsSet, + }, + { + name: "no fields set", + fields: memo.FieldsV0{}, + expectedFlags: 0b00000000, + }, + { + name: "a few fields set", + fields: memo.FieldsV0{ + Receiver: fAddress, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + RevertMessage: fBytes, + }, + }, + expectedFlags: 0b00010101, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // get the data flags + flags := tc.fields.DataFlags() + + // compare the flags + require.Equal(t, tc.expectedFlags, flags) + }) + } +} diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go index d4db80e297..fbb706fc8e 100644 --- a/pkg/memo/header_test.go +++ b/pkg/memo/header_test.go @@ -21,9 +21,9 @@ func Test_Header_EncodeToBytes(t *testing.T) { Version: 0, EncodingFormat: memo.EncodingFmtABI, OpCode: memo.OpCodeCall, - DataFlags: 0b00001111, + DataFlags: 0b00011111, }, - expected: []byte{memo.Identifier, 0b00000000, 0b00100000, 0b00001111}, + expected: []byte{memo.Identifier, 0b00000000, 0b00100000, 0b00011111}, }, { name: "header validation failed", diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index 7aa201690b..152b078c5e 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -22,6 +22,10 @@ type InboundMemo struct { // - Any provided 'DataFlags' is ignored as they are calculated based on the fields set in the memo. // - The 'RevertGasLimit' is not used for now for non-EVM chains. func (m *InboundMemo) EncodeToBytes() ([]byte, error) { + // build fields flags + dataFlags := m.FieldsV0.DataFlags() + m.Header.DataFlags = dataFlags + // encode head head, err := m.Header.EncodeToBytes() if err != nil { @@ -32,7 +36,7 @@ func (m *InboundMemo) EncodeToBytes() ([]byte, error) { var data []byte switch m.Version { case 0: - m.DataFlags, data, err = m.FieldsV0.Pack(m.OpCode, m.EncodingFormat) + data, err = m.FieldsV0.Pack(m.OpCode, m.EncodingFormat, dataFlags) default: return nil, fmt.Errorf("invalid memo version: %d", m.Version) } @@ -40,9 +44,6 @@ func (m *InboundMemo) EncodeToBytes() ([]byte, error) { return nil, errors.Wrapf(err, "failed to pack memo fields version: %d", m.Version) } - // update data flags with the calculated value - head[3] = m.DataFlags - return append(head, data...), nil } @@ -61,7 +62,7 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { // decode fields based on version switch memo.Version { case 0: - err = memo.FieldsV0.Unpack(memo.OpCode, memo.EncodingFormat, memo.DataFlags, data[HeaderSize:]) + err = memo.FieldsV0.Unpack(memo.OpCode, memo.EncodingFormat, memo.Header.DataFlags, data[HeaderSize:]) default: return nil, fmt.Errorf("invalid memo version: %d", memo.Version) } diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index e439880f7c..10c8094e5c 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -171,7 +171,7 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { Version: 0, EncodingFormat: memo.EncodingFmtABI, OpCode: memo.OpCodeDepositAndCall, - DataFlags: 0b00001111, + DataFlags: 0b00011111, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, @@ -206,7 +206,7 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { Version: 0, EncodingFormat: memo.EncodingFmtCompactLong, OpCode: memo.OpCodeDepositAndCall, - DataFlags: 0b00001111, + DataFlags: 0b00011111, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, From d022be1053ebba9c11d5d3d5c51755b3705b4c03 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Tue, 15 Oct 2024 15:11:20 -0500 Subject: [PATCH 16/19] move bits.go to bits folder; type defines for OpCode and EncodingFormat; add more func descriptions --- pkg/math/{ => bits}/bits.go | 4 + pkg/math/{ => bits}/bits_test.go | 0 pkg/memo/codec.go | 20 ++--- pkg/memo/codec_compact.go | 5 +- pkg/memo/codec_compact_test.go | 128 +++++++++++++++---------------- pkg/memo/codec_test.go | 36 ++++----- pkg/memo/fields.go | 6 +- pkg/memo/fields_v0.go | 44 +++++------ pkg/memo/fields_v0_test.go | 124 +++++++++++++++--------------- pkg/memo/header.go | 54 ++++++------- pkg/memo/header_test.go | 30 ++++---- pkg/memo/memo.go | 4 +- pkg/memo/memo_test.go | 65 +++++++++------- testutil/sample/memo.go | 2 +- 14 files changed, 271 insertions(+), 251 deletions(-) rename pkg/math/{ => bits}/bits.go (86%) rename pkg/math/{ => bits}/bits_test.go (100%) diff --git a/pkg/math/bits.go b/pkg/math/bits/bits.go similarity index 86% rename from pkg/math/bits.go rename to pkg/math/bits/bits.go index b00e1539b0..a7a7865793 100644 --- a/pkg/math/bits.go +++ b/pkg/math/bits/bits.go @@ -9,6 +9,10 @@ func SetBit(b *byte, position uint8) { if position > 7 { return } + + // Example: given b = 0b00000000 and position = 3 + // step-1: shift value 1 to left by 3 times: 1 << 3 = 0b00001000 + // step-2: make an OR operation with original byte to set the bit: 0b00000000 | 0b00001000 = 0b00001000 *b |= 1 << position } diff --git a/pkg/math/bits_test.go b/pkg/math/bits/bits_test.go similarity index 100% rename from pkg/math/bits_test.go rename to pkg/math/bits/bits_test.go diff --git a/pkg/memo/codec.go b/pkg/memo/codec.go index d1f592f4b3..ea60fa24c1 100644 --- a/pkg/memo/codec.go +++ b/pkg/memo/codec.go @@ -4,19 +4,21 @@ import ( "fmt" ) +type EncodingFormat uint8 + // Enum for non-EVM chain memo encoding format (2 bits) const ( // EncodingFmtABI represents ABI encoding format - EncodingFmtABI uint8 = 0b0000 + EncodingFmtABI EncodingFormat = 0b0000 // EncodingFmtCompactShort represents 'compact short' encoding format - EncodingFmtCompactShort uint8 = 0b0001 + EncodingFmtCompactShort EncodingFormat = 0b0001 // EncodingFmtCompactLong represents 'compact long' encoding format - EncodingFmtCompactLong uint8 = 0b0010 + EncodingFmtCompactLong EncodingFormat = 0b0010 // EncodingFmtInvalid represents invalid encoding format - EncodingFmtInvalid uint8 = 0b0011 + EncodingFmtInvalid EncodingFormat = 0b0011 ) // Enum for length of bytes used to encode compact data @@ -38,7 +40,7 @@ type Codec interface { } // GetLenBytes returns the number of bytes used to encode the length of the data -func GetLenBytes(encodingFmt uint8) (int, error) { +func GetLenBytes(encodingFmt EncodingFormat) (int, error) { switch encodingFmt { case EncodingFmtCompactShort: return LenBytesShort, nil @@ -50,13 +52,13 @@ func GetLenBytes(encodingFmt uint8) (int, error) { } // GetCodec returns the codec based on the encoding format -func GetCodec(encodingFormat uint8) (Codec, error) { - switch encodingFormat { +func GetCodec(encodingFmt EncodingFormat) (Codec, error) { + switch encodingFmt { case EncodingFmtABI: return NewCodecABI(), nil case EncodingFmtCompactShort, EncodingFmtCompactLong: - return NewCodecCompact(encodingFormat) + return NewCodecCompact(encodingFmt) default: - return nil, fmt.Errorf("invalid encoding format %d", encodingFormat) + return nil, fmt.Errorf("invalid encoding format %d", encodingFmt) } } diff --git a/pkg/memo/codec_compact.go b/pkg/memo/codec_compact.go index 379e627a8f..bcb49d3f92 100644 --- a/pkg/memo/codec_compact.go +++ b/pkg/memo/codec_compact.go @@ -12,6 +12,9 @@ import ( var _ Codec = (*CodecCompact)(nil) // CodecCompact is a coder/decoder for compact encoded memo fields +// +// This encoding format concatenates the memo fields into a single byte array +// with zero padding to minimize the total size of the memo. type CodecCompact struct { // lenBytes is the number of bytes used to encode the length of the data lenBytes int @@ -21,7 +24,7 @@ type CodecCompact struct { } // NewCodecCompact creates a new compact codec -func NewCodecCompact(encodingFmt uint8) (*CodecCompact, error) { +func NewCodecCompact(encodingFmt EncodingFormat) (*CodecCompact, error) { lenBytes, err := GetLenBytes(encodingFmt) if err != nil { return nil, err diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go index 72600bf54f..9cec7d122d 100644 --- a/pkg/memo/codec_compact_test.go +++ b/pkg/memo/codec_compact_test.go @@ -12,24 +12,24 @@ import ( func Test_NewCodecCompact(t *testing.T) { tests := []struct { - name string - encodingFmt uint8 - fail bool + name string + encodeFmt memo.EncodingFormat + fail bool }{ { - name: "create codec compact successfully", - encodingFmt: memo.EncodingFmtCompactShort, + name: "create codec compact successfully", + encodeFmt: memo.EncodingFmtCompactShort, }, { - name: "create codec compact failed on invalid encoding format", - encodingFmt: 0b11, - fail: true, + name: "create codec compact failed on invalid encoding format", + encodeFmt: 0b11, + fail: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - codec, err := memo.NewCodecCompact(tc.encodingFmt) + codec, err := memo.NewCodecCompact(tc.encodeFmt) if tc.fail { require.Error(t, err) require.Nil(t, codec) @@ -64,14 +64,14 @@ func Test_CodecCompact_PackArguments(t *testing.T) { // test cases tests := []struct { name string - encodingFmt uint8 + encodeFmt memo.EncodingFormat args []memo.CodecArg expectedLen int errMsg string }{ { - name: "pack arguments of [address, bytes, string] in compact-short format", - encodingFmt: memo.EncodingFmtCompactShort, + name: "pack arguments of [address, bytes, string] in compact-short format", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes), @@ -80,8 +80,8 @@ func Test_CodecCompact_PackArguments(t *testing.T) { expectedLen: 20 + 1 + len(argBytes) + 1 + len([]byte(argString)), }, { - name: "pack arguments of [string, address, bytes] in compact-long format", - encodingFmt: memo.EncodingFmtCompactLong, + name: "pack arguments of [string, address, bytes] in compact-long format", + encodeFmt: memo.EncodingFmtCompactLong, args: []memo.CodecArg{ memo.ArgRevertAddress(argString), memo.ArgReceiver(argAddress), @@ -90,32 +90,32 @@ func Test_CodecCompact_PackArguments(t *testing.T) { expectedLen: 2 + len([]byte(argString)) + 20 + 2 + len(argBytes), }, { - name: "pack long string (> 255 bytes) with compact-long format", - encodingFmt: memo.EncodingFmtCompactLong, + name: "pack long string (> 255 bytes) with compact-long format", + encodeFmt: memo.EncodingFmtCompactLong, args: []memo.CodecArg{ memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 256))), }, expectedLen: 2 + 256, }, { - name: "pack long string (> 255 bytes) with compact-short format should fail", - encodingFmt: memo.EncodingFmtCompactShort, + name: "pack long string (> 255 bytes) with compact-short format should fail", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 256))), }, errMsg: "exceeds 255 bytes", }, { - name: "pack long string (> 65535 bytes) with compact-long format should fail", - encodingFmt: memo.EncodingFmtCompactLong, + name: "pack long string (> 65535 bytes) with compact-long format should fail", + encodeFmt: memo.EncodingFmtCompactLong, args: []memo.CodecArg{ memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 65536))), }, errMsg: "exceeds 65535 bytes", }, { - name: "pack empty byte array and string arguments", - encodingFmt: memo.EncodingFmtCompactShort, + name: "pack empty byte array and string arguments", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.ArgPayload([]byte{}), memo.ArgRevertAddress(""), @@ -123,32 +123,32 @@ func Test_CodecCompact_PackArguments(t *testing.T) { expectedLen: 2, }, { - name: "failed to pack bytes argument if string is passed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to pack bytes argument if string is passed", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.ArgPayload(argString), // expect bytes type, but passed string }, errMsg: "argument is not of type []byte", }, { - name: "failed to pack address argument if bytes is passed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to pack address argument if bytes is passed", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.ArgReceiver(argBytes), // expect address type, but passed bytes }, errMsg: "argument is not of type common.Address", }, { - name: "failed to pack string argument if bytes is passed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to pack string argument if bytes is passed", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.ArgRevertAddress(argBytes), // expect string type, but passed bytes }, errMsg: "argument is not of type string", }, { - name: "failed to pack unsupported argument type", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to pack unsupported argument type", + encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ memo.NewArg("receiver", memo.ArgType("unknown"), nil), }, @@ -160,7 +160,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // create a new compact codec and add arguments - codec, err := memo.NewCodecCompact(tc.encodingFmt) + codec, err := memo.NewCodecCompact(tc.encodeFmt) require.NoError(t, err) codec.AddArguments(tc.args...) @@ -176,7 +176,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { require.Equal(t, tc.expectedLen, len(packedData)) // calc expected data for comparison - expectedData := sample.CompactPack(tc.encodingFmt, tc.args...) + expectedData := sample.CompactPack(tc.encodeFmt, tc.args...) // validate the packed data require.True(t, bytes.Equal(expectedData, packedData), "compact encoded data mismatch") @@ -192,15 +192,15 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { // test cases tests := []struct { - name string - encodingFmt uint8 - data []byte - expected []memo.CodecArg - errMsg string + name string + encodeFmt memo.EncodingFormat + data []byte + expected []memo.CodecArg + errMsg string }{ { - name: "unpack arguments of [address, bytes, string] in compact-short format", - encodingFmt: memo.EncodingFmtCompactShort, + name: "unpack arguments of [address, bytes, string] in compact-short format", + encodeFmt: memo.EncodingFmtCompactShort, data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), @@ -214,8 +214,8 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { }, }, { - name: "unpack arguments of [string, address, bytes] in compact-long format", - encodingFmt: memo.EncodingFmtCompactLong, + name: "unpack arguments of [string, address, bytes] in compact-long format", + encodeFmt: memo.EncodingFmtCompactLong, data: sample.CompactPack( memo.EncodingFmtCompactLong, memo.ArgRevertAddress(argString), @@ -229,8 +229,8 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { }, }, { - name: "unpack empty byte array and string argument", - encodingFmt: memo.EncodingFmtCompactShort, + name: "unpack empty byte array and string argument", + encodeFmt: memo.EncodingFmtCompactShort, data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload([]byte{}), @@ -242,35 +242,35 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { }, }, { - name: "failed to unpack address if data length < 20 bytes", - encodingFmt: memo.EncodingFmtCompactShort, - data: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + name: "failed to unpack address if data length < 20 bytes", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, expected: []memo.CodecArg{ memo.ArgReceiver(argAddress), }, errMsg: "expected address, got 5 bytes", }, { - name: "failed to unpack string if data length < 1 byte", - encodingFmt: memo.EncodingFmtCompactShort, - data: []byte{}, + name: "failed to unpack string if data length < 1 byte", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{}, expected: []memo.CodecArg{ memo.ArgRevertAddress(argString), }, errMsg: "expected 1 bytes to decode length", }, { - name: "failed to unpack string if actual data is less than decoded length", - encodingFmt: memo.EncodingFmtCompactShort, - data: []byte{0x05, 0x0a, 0x0b, 0x0c, 0x0d}, // length = 5, but only 4 bytes provided + name: "failed to unpack string if actual data is less than decoded length", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{0x05, 0x0a, 0x0b, 0x0c, 0x0d}, // length = 5, but only 4 bytes provided expected: []memo.CodecArg{ memo.ArgPayload(argBytes), }, errMsg: "expected 5 bytes, got 4", }, { - name: "failed to unpack bytes argument if string is passed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to unpack bytes argument if string is passed", + encodeFmt: memo.EncodingFmtCompactShort, data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload(argBytes), @@ -281,8 +281,8 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { errMsg: "argument is not of type *[]byte", }, { - name: "failed to unpack address argument if bytes is passed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to unpack address argument if bytes is passed", + encodeFmt: memo.EncodingFmtCompactShort, data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), @@ -293,8 +293,8 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { errMsg: "argument is not of type *common.Address", }, { - name: "failed to unpack string argument if address is passed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "failed to unpack string argument if address is passed", + encodeFmt: memo.EncodingFmtCompactShort, data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgRevertAddress(argString), @@ -305,17 +305,17 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { errMsg: "argument is not of type *string", }, { - name: "failed to unpack unsupported argument type", - encodingFmt: memo.EncodingFmtCompactShort, - data: []byte{}, + name: "failed to unpack unsupported argument type", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{}, expected: []memo.CodecArg{ memo.NewArg("payload", memo.ArgType("unknown"), nil), }, errMsg: "unsupported argument (payload) type", }, { - name: "unpacking should fail if not all data is consumed", - encodingFmt: memo.EncodingFmtCompactShort, + name: "unpacking should fail if not all data is consumed", + encodeFmt: memo.EncodingFmtCompactShort, data: func() []byte { data := sample.CompactPack( memo.EncodingFmtCompactShort, @@ -337,7 +337,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // create a new compact codec and add arguments - codec, err := memo.NewCodecCompact(tc.encodingFmt) + codec, err := memo.NewCodecCompact(tc.encodeFmt) require.NoError(t, err) // add output arguments diff --git a/pkg/memo/codec_test.go b/pkg/memo/codec_test.go index 6ac3a3cd0d..aa27667967 100644 --- a/pkg/memo/codec_test.go +++ b/pkg/memo/codec_test.go @@ -11,23 +11,23 @@ func Test_GetLenBytes(t *testing.T) { // Define table-driven test cases tests := []struct { name string - encodingFmt uint8 + encodeFmt memo.EncodingFormat expectedLen int expectErr bool }{ { name: "compact short", - encodingFmt: memo.EncodingFmtCompactShort, + encodeFmt: memo.EncodingFmtCompactShort, expectedLen: 1, }, { name: "compact long", - encodingFmt: memo.EncodingFmtCompactLong, + encodeFmt: memo.EncodingFmtCompactLong, expectedLen: 2, }, { name: "non-compact encoding format", - encodingFmt: memo.EncodingFmtABI, + encodeFmt: memo.EncodingFmtABI, expectedLen: 0, expectErr: true, }, @@ -36,7 +36,7 @@ func Test_GetLenBytes(t *testing.T) { // Loop through each test case for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - length, err := memo.GetLenBytes(tc.encodingFmt) + length, err := memo.GetLenBytes(tc.encodeFmt) // Check if error is expected if tc.expectErr { @@ -53,33 +53,33 @@ func Test_GetLenBytes(t *testing.T) { func Test_GetCodec(t *testing.T) { // Define table-driven test cases tests := []struct { - name string - encodingFmt uint8 - errMsg string + name string + encodeFmt memo.EncodingFormat + errMsg string }{ { - name: "should get ABI codec", - encodingFmt: memo.EncodingFmtABI, + name: "should get ABI codec", + encodeFmt: memo.EncodingFmtABI, }, { - name: "should get compact codec", - encodingFmt: memo.EncodingFmtCompactShort, + name: "should get compact codec", + encodeFmt: memo.EncodingFmtCompactShort, }, { - name: "should get compact codec", - encodingFmt: memo.EncodingFmtCompactLong, + name: "should get compact codec", + encodeFmt: memo.EncodingFmtCompactLong, }, { - name: "should fail to get codec", - encodingFmt: 0b0011, - errMsg: "invalid encoding format", + name: "should fail to get codec", + encodeFmt: 0b0011, + errMsg: "invalid encoding format", }, } // Loop through each test case for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - codec, err := memo.GetCodec(tc.encodingFmt) + codec, err := memo.GetCodec(tc.encodeFmt) if tc.errMsg != "" { require.Error(t, err) require.Nil(t, codec) diff --git a/pkg/memo/fields.go b/pkg/memo/fields.go index 9baeab19af..fff853f955 100644 --- a/pkg/memo/fields.go +++ b/pkg/memo/fields.go @@ -3,13 +3,13 @@ package memo // Fields is the interface for memo fields type Fields interface { // Pack encodes the memo fields - Pack(opCode, encodingFormat, dataFlags uint8) ([]byte, error) + Pack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8) ([]byte, error) // Unpack decodes the memo fields - Unpack(opCode, encodingFormat, dataFlags uint8, data []byte) error + Unpack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8, data []byte) error // Validate checks if the fields are valid - Validate(opCode, dataFlags uint8) error + Validate(opCode OpCode, dataFlags uint8) error // DataFlags build the data flags for the fields DataFlags() uint8 diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go index e6b6a5f986..a8f79d99ba 100644 --- a/pkg/memo/fields_v0.go +++ b/pkg/memo/fields_v0.go @@ -5,7 +5,7 @@ import ( "github.com/pkg/errors" "github.com/zeta-chain/node/pkg/crypto" - zetamath "github.com/zeta-chain/node/pkg/math" + zetabits "github.com/zeta-chain/node/pkg/math/bits" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" ) @@ -38,14 +38,14 @@ type FieldsV0 struct { } // Pack encodes the memo fields -func (f *FieldsV0) Pack(opCode uint8, encodingFormat uint8, dataFlags uint8) ([]byte, error) { +func (f *FieldsV0) Pack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8) ([]byte, error) { // validate fields err := f.Validate(opCode, dataFlags) if err != nil { return nil, err } - codec, err := GetCodec(encodingFormat) + codec, err := GetCodec(encodingFmt) if err != nil { return nil, errors.Wrap(err, "unable to get codec") } @@ -54,8 +54,8 @@ func (f *FieldsV0) Pack(opCode uint8, encodingFormat uint8, dataFlags uint8) ([] } // Unpack decodes the memo fields -func (f *FieldsV0) Unpack(opCode uint8, encodingFormat uint8, dataFlags uint8, data []byte) error { - codec, err := GetCodec(encodingFormat) +func (f *FieldsV0) Unpack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8, data []byte) error { + codec, err := GetCodec(encodingFmt) if err != nil { return errors.Wrap(err, "unable to get codec") } @@ -69,9 +69,9 @@ func (f *FieldsV0) Unpack(opCode uint8, encodingFormat uint8, dataFlags uint8, d } // Validate checks if the fields are valid -func (f *FieldsV0) Validate(opCode uint8, dataFlags uint8) error { +func (f *FieldsV0) Validate(opCode OpCode, dataFlags uint8) error { // receiver address must be a valid address - if zetamath.IsBitSet(dataFlags, bitPosReceiver) && crypto.IsEmptyAddress(f.Receiver) { + if zetabits.IsBitSet(dataFlags, bitPosReceiver) && crypto.IsEmptyAddress(f.Receiver) { return errors.New("receiver address is empty") } @@ -81,7 +81,7 @@ func (f *FieldsV0) Validate(opCode uint8, dataFlags uint8) error { } // abort address must be a valid address - if zetamath.IsBitSet(dataFlags, bitPosAbortAddress) && !common.IsHexAddress(f.RevertOptions.AbortAddress) { + if zetabits.IsBitSet(dataFlags, bitPosAbortAddress) && !common.IsHexAddress(f.RevertOptions.AbortAddress) { return errors.New("invalid abort address") } @@ -101,27 +101,27 @@ func (f *FieldsV0) DataFlags() uint8 { // set 'receiver' flag if provided if !crypto.IsEmptyAddress(f.Receiver) { - zetamath.SetBit(&dataFlags, bitPosReceiver) + zetabits.SetBit(&dataFlags, bitPosReceiver) } // set 'payload' flag if provided if len(f.Payload) > 0 { - zetamath.SetBit(&dataFlags, bitPosPayload) + zetabits.SetBit(&dataFlags, bitPosPayload) } // set 'revertAddress' flag if provided if f.RevertOptions.RevertAddress != "" { - zetamath.SetBit(&dataFlags, bitPosRevertAddress) + zetabits.SetBit(&dataFlags, bitPosRevertAddress) } // set 'abortAddress' flag if provided if f.RevertOptions.AbortAddress != "" { - zetamath.SetBit(&dataFlags, bitPosAbortAddress) + zetabits.SetBit(&dataFlags, bitPosAbortAddress) } // set 'revertMessage' flag if provided if f.RevertOptions.CallOnRevert { - zetamath.SetBit(&dataFlags, bitPosRevertMessage) + zetabits.SetBit(&dataFlags, bitPosRevertMessage) } return dataFlags @@ -130,23 +130,23 @@ func (f *FieldsV0) DataFlags() uint8 { // packFieldsV0 packs the memo fields for version 0 func (f *FieldsV0) packFields(codec Codec, dataFlags uint8) ([]byte, error) { // add 'receiver' argument optionally - if zetamath.IsBitSet(dataFlags, bitPosReceiver) { + if zetabits.IsBitSet(dataFlags, bitPosReceiver) { codec.AddArguments(ArgReceiver(f.Receiver)) } // add 'payload' argument optionally - if zetamath.IsBitSet(dataFlags, bitPosPayload) { + if zetabits.IsBitSet(dataFlags, bitPosPayload) { codec.AddArguments(ArgPayload(f.Payload)) } // add 'revertAddress' argument optionally - if zetamath.IsBitSet(dataFlags, bitPosRevertAddress) { + if zetabits.IsBitSet(dataFlags, bitPosRevertAddress) { codec.AddArguments(ArgRevertAddress(f.RevertOptions.RevertAddress)) } // add 'abortAddress' argument optionally abortAddress := common.HexToAddress(f.RevertOptions.AbortAddress) - if zetamath.IsBitSet(dataFlags, bitPosAbortAddress) { + if zetabits.IsBitSet(dataFlags, bitPosAbortAddress) { codec.AddArguments(ArgAbortAddress(abortAddress)) } @@ -167,28 +167,28 @@ func (f *FieldsV0) packFields(codec Codec, dataFlags uint8) ([]byte, error) { // unpackFields unpacks the memo fields for version 0 func (f *FieldsV0) unpackFields(codec Codec, dataFlags byte, data []byte) error { // add 'receiver' argument optionally - if zetamath.IsBitSet(dataFlags, bitPosReceiver) { + if zetabits.IsBitSet(dataFlags, bitPosReceiver) { codec.AddArguments(ArgReceiver(&f.Receiver)) } // add 'payload' argument optionally - if zetamath.IsBitSet(dataFlags, bitPosPayload) { + if zetabits.IsBitSet(dataFlags, bitPosPayload) { codec.AddArguments(ArgPayload(&f.Payload)) } // add 'revertAddress' argument optionally - if zetamath.IsBitSet(dataFlags, bitPosRevertAddress) { + if zetabits.IsBitSet(dataFlags, bitPosRevertAddress) { codec.AddArguments(ArgRevertAddress(&f.RevertOptions.RevertAddress)) } // add 'abortAddress' argument optionally var abortAddress common.Address - if zetamath.IsBitSet(dataFlags, bitPosAbortAddress) { + if zetabits.IsBitSet(dataFlags, bitPosAbortAddress) { codec.AddArguments(ArgAbortAddress(&abortAddress)) } // add 'revertMessage' argument optionally - f.RevertOptions.CallOnRevert = zetamath.IsBitSet(dataFlags, bitPosRevertMessage) + f.RevertOptions.CallOnRevert = zetabits.IsBitSet(dataFlags, bitPosRevertMessage) if f.RevertOptions.CallOnRevert { codec.AddArguments(ArgRevertMessage(&f.RevertOptions.RevertMessage)) } diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index 124b7af70f..1339f4370a 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -23,20 +23,20 @@ func Test_V0_Pack(t *testing.T) { fString := "this_is_a_string_field" tests := []struct { - name string - opCode uint8 - encodingFormat uint8 - dataFlags uint8 - fields memo.FieldsV0 - expectedFlags byte - expectedData []byte - errMsg string + name string + opCode memo.OpCode + encodeFmt memo.EncodingFormat + dataFlags uint8 + fields memo.FieldsV0 + expectedFlags byte + expectedData []byte + errMsg string }{ { - name: "pack all fields with ABI encoding", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtABI, - dataFlags: flagsAllFieldsSet, // all fields are set + name: "pack all fields with ABI encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: flagsAllFieldsSet, // all fields are set fields: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -55,10 +55,10 @@ func Test_V0_Pack(t *testing.T) { memo.ArgRevertMessage(fBytes)), }, { - name: "pack all fields with compact encoding", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtCompactShort, - dataFlags: flagsAllFieldsSet, // all fields are set + name: "pack all fields with compact encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: flagsAllFieldsSet, // all fields are set fields: memo.FieldsV0{ Receiver: fAddress, Payload: fBytes, @@ -78,10 +78,10 @@ func Test_V0_Pack(t *testing.T) { memo.ArgRevertMessage(fBytes)), }, { - name: "fields validation failed due to empty receiver address", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtABI, - dataFlags: 0b00000001, // receiver flag is set + name: "fields validation failed due to empty receiver address", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: 0b00000001, // receiver flag is set fields: memo.FieldsV0{ Receiver: common.Address{}, }, @@ -93,15 +93,15 @@ func Test_V0_Pack(t *testing.T) { fields: memo.FieldsV0{ Receiver: fAddress, }, - encodingFormat: 0x0F, - errMsg: "unable to get codec", + encodeFmt: 0x0F, + errMsg: "unable to get codec", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // pack the fields - data, err := tc.fields.Pack(tc.opCode, tc.encodingFormat, tc.dataFlags) + data, err := tc.fields.Pack(tc.opCode, tc.encodeFmt, tc.dataFlags) // validate the error message if tc.errMsg != "" { @@ -124,19 +124,19 @@ func Test_V0_Unpack(t *testing.T) { fString := "this_is_a_string_field" tests := []struct { - name string - opCode uint8 - encodingFormat uint8 - dataFlags byte - data []byte - expected memo.FieldsV0 - errMsg string + name string + opCode memo.OpCode + encodeFmt memo.EncodingFormat + dataFlags byte + data []byte + expected memo.FieldsV0 + errMsg string }{ { - name: "unpack all fields with ABI encoding", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtABI, - dataFlags: flagsAllFieldsSet, // all fields are set + name: "unpack all fields with ABI encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: flagsAllFieldsSet, // all fields are set data: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -155,10 +155,10 @@ func Test_V0_Unpack(t *testing.T) { }, }, { - name: "unpack all fields with compact encoding", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtCompactShort, - dataFlags: flagsAllFieldsSet, // all fields are set + name: "unpack all fields with compact encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: flagsAllFieldsSet, // all fields are set data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), @@ -178,47 +178,47 @@ func Test_V0_Unpack(t *testing.T) { }, }, { - name: "unpack empty ABI encoded payload if flag is set", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtABI, - dataFlags: 0b00000010, // payload flags are set + name: "unpack empty ABI encoded payload if flag is set", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: 0b00000010, // payload flags are set data: sample.ABIPack(t, memo.ArgPayload([]byte{})), // empty payload expected: memo.FieldsV0{}, }, { - name: "unpack empty compact encoded payload if flag is set", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtCompactShort, - dataFlags: 0b00000010, // payload flag is set + name: "unpack empty compact encoded payload if flag is set", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: 0b00000010, // payload flag is set data: sample.CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload([]byte{})), // empty payload expected: memo.FieldsV0{}, }, { - name: "unable to get codec on invalid encoding format", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: 0x0F, - dataFlags: 0b00000001, - data: []byte{}, - errMsg: "unable to get codec", + name: "unable to get codec on invalid encoding format", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: 0x0F, + dataFlags: 0b00000001, + data: []byte{}, + errMsg: "unable to get codec", }, { - name: "failed to unpack ABI encoded data with compact encoding format", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtCompactShort, - dataFlags: 0b00000011, // receiver and payload flags are set + name: "failed to unpack ABI encoded data with compact encoding format", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: 0b00000011, // receiver and payload flags are set data: sample.ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes)), errMsg: "failed to unpack arguments", }, { - name: "fields validation failed due to empty receiver address", - opCode: memo.OpCodeDepositAndCall, - encodingFormat: memo.EncodingFmtABI, - dataFlags: 0b00000011, // receiver and payload flags are set + name: "fields validation failed due to empty receiver address", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: 0b00000011, // receiver and payload flags are set data: sample.ABIPack(t, memo.ArgReceiver(common.Address{}), memo.ArgPayload(fBytes)), @@ -230,7 +230,7 @@ func Test_V0_Unpack(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // unpack the fields fields := memo.FieldsV0{} - err := fields.Unpack(tc.opCode, tc.encodingFormat, tc.dataFlags, tc.data) + err := fields.Unpack(tc.opCode, tc.encodeFmt, tc.dataFlags, tc.data) // validate the error message if tc.errMsg != "" { @@ -253,7 +253,7 @@ func Test_V0_Validate(t *testing.T) { tests := []struct { name string - opCode uint8 + opCode memo.OpCode dataFlags uint8 fields memo.FieldsV0 errMsg string diff --git a/pkg/memo/header.go b/pkg/memo/header.go index fc614eacab..5f9bcd1e14 100644 --- a/pkg/memo/header.go +++ b/pkg/memo/header.go @@ -5,9 +5,11 @@ import ( "github.com/pkg/errors" - zetamath "github.com/zeta-chain/node/pkg/math" + zetabits "github.com/zeta-chain/node/pkg/math/bits" ) +type OpCode uint8 + const ( // Identifier is the ASCII code of 'Z' (0x5A) Identifier byte = 0x5A @@ -15,25 +17,25 @@ const ( // HeaderSize is the size of the memo header: [identifier + ctrlByte1+ ctrlByte2 + dataFlags] HeaderSize = 4 - // MaskVersion is the mask for the version bits(upper 4 bits) - MaskVersion byte = 0b11110000 + // maskVersion is the mask for the version bits(upper 4 bits) + maskVersion byte = 0b11110000 - // MaskEncodingFormat is the mask for the encoding format bits(lower 4 bits) - MaskEncodingFormat byte = 0b00001111 + // maskEncodingFormat is the mask for the encoding format bits(lower 4 bits) + maskEncodingFormat byte = 0b00001111 - // MaskOpCode is the mask for the operation code bits(upper 4 bits) - MaskOpCode byte = 0b11110000 + // maskOpCode is the mask for the operation code bits(upper 4 bits) + maskOpCode byte = 0b11110000 - // MaskCtrlReserved is the mask for reserved control bits (lower 4 bits) - MaskCtrlReserved byte = 0b00001111 + // maskCtrlReserved is the mask for reserved control bits (lower 4 bits) + maskCtrlReserved byte = 0b00001111 ) // Enum for non-EVM chain inbound operation code (4 bits) const ( - OpCodeDeposit uint8 = 0b0000 // operation 'deposit' - OpCodeDepositAndCall uint8 = 0b0001 // operation 'deposit_and_call' - OpCodeCall uint8 = 0b0010 // operation 'call' - OpCodeInvalid uint8 = 0b0011 // invalid operation code + OpCodeDeposit OpCode = 0b0000 // operation 'deposit' + OpCodeDepositAndCall OpCode = 0b0001 // operation 'deposit_and_call' + OpCodeCall OpCode = 0b0010 // operation 'call' + OpCodeInvalid OpCode = 0b0011 // invalid operation code ) // Header represent the memo header @@ -41,11 +43,11 @@ type Header struct { // Version is the memo Version Version uint8 - // EncodingFormat is the memo encoding format - EncodingFormat uint8 + // EncodingFmt is the memo encoding format + EncodingFmt EncodingFormat // OpCode is the inbound operation code - OpCode uint8 + OpCode OpCode // Reserved is the reserved control bits Reserved uint8 @@ -69,14 +71,14 @@ func (h *Header) EncodeToBytes() ([]byte, error) { // set version #, encoding format var ctrlByte1 byte - ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskVersion, h.Version) - ctrlByte1 = zetamath.SetBits(ctrlByte1, MaskEncodingFormat, h.EncodingFormat) + ctrlByte1 = zetabits.SetBits(ctrlByte1, maskVersion, h.Version) + ctrlByte1 = zetabits.SetBits(ctrlByte1, maskEncodingFormat, byte(h.EncodingFmt)) data[1] = ctrlByte1 // set operation code, reserved bits var ctrlByte2 byte - ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskOpCode, h.OpCode) - ctrlByte2 = zetamath.SetBits(ctrlByte2, MaskCtrlReserved, h.Reserved) + ctrlByte2 = zetabits.SetBits(ctrlByte2, maskOpCode, byte(h.OpCode)) + ctrlByte2 = zetabits.SetBits(ctrlByte2, maskCtrlReserved, h.Reserved) data[2] = ctrlByte2 // set data flags @@ -99,13 +101,13 @@ func (h *Header) DecodeFromBytes(data []byte) error { // extract version #, encoding format ctrlByte1 := data[1] - h.Version = zetamath.GetBits(ctrlByte1, MaskVersion) - h.EncodingFormat = zetamath.GetBits(ctrlByte1, MaskEncodingFormat) + h.Version = zetabits.GetBits(ctrlByte1, maskVersion) + h.EncodingFmt = EncodingFormat(zetabits.GetBits(ctrlByte1, maskEncodingFormat)) // extract operation code, reserved bits ctrlByte2 := data[2] - h.OpCode = zetamath.GetBits(ctrlByte2, MaskOpCode) - h.Reserved = zetamath.GetBits(ctrlByte2, MaskCtrlReserved) + h.OpCode = OpCode(zetabits.GetBits(ctrlByte2, maskOpCode)) + h.Reserved = zetabits.GetBits(ctrlByte2, maskCtrlReserved) // extract data flags h.DataFlags = data[3] @@ -120,8 +122,8 @@ func (h *Header) Validate() error { return fmt.Errorf("invalid memo version: %d", h.Version) } - if h.EncodingFormat >= EncodingFmtInvalid { - return fmt.Errorf("invalid encoding format: %d", h.EncodingFormat) + if h.EncodingFmt >= EncodingFmtInvalid { + return fmt.Errorf("invalid encoding format: %d", h.EncodingFmt) } if h.OpCode >= OpCodeInvalid { diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go index fbb706fc8e..31693a7c18 100644 --- a/pkg/memo/header_test.go +++ b/pkg/memo/header_test.go @@ -18,10 +18,10 @@ func Test_Header_EncodeToBytes(t *testing.T) { { name: "it works", header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeCall, - DataFlags: 0b00011111, + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + DataFlags: 0b00011111, }, expected: []byte{memo.Identifier, 0b00000000, 0b00100000, 0b00011111}, }, @@ -57,12 +57,14 @@ func Test_Header_DecodeFromBytes(t *testing.T) { }{ { name: "it works", - data: append(sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeCall, 0, 0), []byte{0x01, 0x02}...), + data: append( + sample.MemoHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeCall), 0, 0), + []byte{0x01, 0x02}...), expected: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeCall, - Reserved: 0, + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + Reserved: 0, }, }, { @@ -78,7 +80,7 @@ func Test_Header_DecodeFromBytes(t *testing.T) { { name: "header validation failed", data: append( - sample.MemoHead(0, memo.EncodingFmtInvalid, memo.OpCodeCall, 0, 0), + sample.MemoHead(0, uint8(memo.EncodingFmtInvalid), uint8(memo.OpCodeCall), 0, 0), []byte{0x01, 0x02}...), // invalid encoding format errMsg: "invalid encoding format", }, @@ -107,9 +109,9 @@ func Test_Header_Validate(t *testing.T) { { name: "valid header", header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtCompactShort, - OpCode: memo.OpCodeDepositAndCall, + Version: 0, + EncodingFmt: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, }, }, { @@ -122,7 +124,7 @@ func Test_Header_Validate(t *testing.T) { { name: "invalid encoding format", header: memo.Header{ - EncodingFormat: memo.EncodingFmtInvalid, + EncodingFmt: memo.EncodingFmtInvalid, }, errMsg: "invalid encoding format", }, diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index 152b078c5e..420bcac6fa 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -36,7 +36,7 @@ func (m *InboundMemo) EncodeToBytes() ([]byte, error) { var data []byte switch m.Version { case 0: - data, err = m.FieldsV0.Pack(m.OpCode, m.EncodingFormat, dataFlags) + data, err = m.FieldsV0.Pack(m.OpCode, m.EncodingFmt, dataFlags) default: return nil, fmt.Errorf("invalid memo version: %d", m.Version) } @@ -62,7 +62,7 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { // decode fields based on version switch memo.Version { case 0: - err = memo.FieldsV0.Unpack(memo.OpCode, memo.EncodingFormat, memo.Header.DataFlags, data[HeaderSize:]) + err = memo.FieldsV0.Unpack(memo.OpCode, memo.EncodingFmt, memo.Header.DataFlags, data[HeaderSize:]) default: return nil, fmt.Errorf("invalid memo version: %d", memo.Version) } diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index 10c8094e5c..1de361e267 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -26,9 +26,9 @@ func Test_Memo_EncodeToBytes(t *testing.T) { name: "encode memo with ABI encoding", memo: &memo.InboundMemo{ Header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeDepositAndCall, + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, @@ -43,8 +43,8 @@ func Test_Memo_EncodeToBytes(t *testing.T) { }, expectedHead: sample.MemoHead( 0, - memo.EncodingFmtABI, - memo.OpCodeDepositAndCall, + uint8(memo.EncodingFmtABI), + uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), @@ -59,9 +59,9 @@ func Test_Memo_EncodeToBytes(t *testing.T) { name: "encode memo with compact encoding", memo: &memo.InboundMemo{ Header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtCompactShort, - OpCode: memo.OpCodeDepositAndCall, + Version: 0, + EncodingFmt: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, @@ -76,8 +76,8 @@ func Test_Memo_EncodeToBytes(t *testing.T) { }, expectedHead: sample.MemoHead( 0, - memo.EncodingFmtCompactShort, - memo.OpCodeDepositAndCall, + uint8(memo.EncodingFmtCompactShort), + uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), @@ -111,9 +111,9 @@ func Test_Memo_EncodeToBytes(t *testing.T) { name: "failed to pack memo fields", memo: &memo.InboundMemo{ Header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeDeposit, + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeDeposit, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, @@ -134,6 +134,11 @@ func Test_Memo_EncodeToBytes(t *testing.T) { } require.NoError(t, err) require.Equal(t, append(tt.expectedHead, tt.expectedData...), data) + + // decode the memo and compare with the original + decodedMemo, err := memo.DecodeFromBytes(data) + require.NoError(t, err) + require.Equal(t, tt.memo, decodedMemo) }) } } @@ -148,6 +153,8 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { name string head []byte data []byte + headHex string + dataHex string expectedMemo memo.InboundMemo errMsg string }{ @@ -155,8 +162,8 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { name: "decode memo with ABI encoding", head: sample.MemoHead( 0, - memo.EncodingFmtABI, - memo.OpCodeDepositAndCall, + uint8(memo.EncodingFmtABI), + uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), @@ -168,10 +175,10 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { memo.ArgRevertMessage(fBytes)), expectedMemo: memo.InboundMemo{ Header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtABI, - OpCode: memo.OpCodeDepositAndCall, - DataFlags: 0b00011111, + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + DataFlags: 0b00011111, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, @@ -189,8 +196,8 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { name: "decode memo with compact encoding", head: sample.MemoHead( 0, - memo.EncodingFmtCompactLong, - memo.OpCodeDepositAndCall, + uint8(memo.EncodingFmtCompactLong), + uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), @@ -203,10 +210,10 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { memo.ArgRevertMessage(fBytes)), expectedMemo: memo.InboundMemo{ Header: memo.Header{ - Version: 0, - EncodingFormat: memo.EncodingFmtCompactLong, - OpCode: memo.OpCodeDepositAndCall, - DataFlags: 0b00011111, + Version: 0, + EncodingFmt: memo.EncodingFmtCompactLong, + OpCode: memo.OpCodeDepositAndCall, + DataFlags: 0b00011111, }, FieldsV0: memo.FieldsV0{ Receiver: fAddress, @@ -222,13 +229,13 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }, { name: "failed to decode memo header", - head: sample.MemoHead(0, memo.EncodingFmtABI, memo.OpCodeInvalid, 0, 0), + head: sample.MemoHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeInvalid), 0, 0), data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), errMsg: "failed to decode memo header", }, { name: "failed to decode if version is invalid", - head: sample.MemoHead(1, memo.EncodingFmtABI, memo.OpCodeDeposit, 0, 0), + head: sample.MemoHead(1, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDeposit), 0, 0), data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), errMsg: "invalid memo version", }, @@ -236,8 +243,8 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { name: "failed to decode compact encoded data with ABI encoding format", head: sample.MemoHead( 0, - memo.EncodingFmtABI, - memo.OpCodeDepositAndCall, + uint8(memo.EncodingFmtABI), + uint8(memo.OpCodeDepositAndCall), 0, 0, ), // header says ABI encoding diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go index eb6a2351fa..a70ce3537c 100644 --- a/testutil/sample/memo.go +++ b/testutil/sample/memo.go @@ -81,7 +81,7 @@ func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { // CompactPack is a helper function to pack arguments into compact encoded data // Note: all arguments are assumed to be <= 65535 bytes for simplicity. -func CompactPack(encodingFmt uint8, args ...memo.CodecArg) []byte { +func CompactPack(encodingFmt memo.EncodingFormat, args ...memo.CodecArg) []byte { var ( length int packedData []byte From 798f5669fc1b7684fc6d17549e56d67aa539f588 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Tue, 15 Oct 2024 17:24:46 -0500 Subject: [PATCH 17/19] move legacy Bitcoin memo decoding to memo package --- pkg/chains/conversion.go | 25 ------------ pkg/chains/utils_test.go | 37 ----------------- pkg/memo/memo.go | 24 +++++++++++ pkg/memo/memo_test.go | 40 ++++++++++++++++++- x/crosschain/keeper/evm_deposit.go | 3 +- zetaclient/chains/bitcoin/observer/inbound.go | 4 +- zetaclient/chains/evm/observer/inbound.go | 6 +-- zetaclient/compliance/compliance.go | 4 +- 8 files changed, 71 insertions(+), 72 deletions(-) diff --git a/pkg/chains/conversion.go b/pkg/chains/conversion.go index 3200b76693..5d03b77007 100644 --- a/pkg/chains/conversion.go +++ b/pkg/chains/conversion.go @@ -1,10 +1,8 @@ package chains import ( - "encoding/hex" "fmt" - "cosmossdk.io/errors" "github.com/btcsuite/btcd/chaincfg/chainhash" ethcommon "github.com/ethereum/go-ethereum/common" ) @@ -28,26 +26,3 @@ func StringToHash(chainID int64, hash string, additionalChains []Chain) ([]byte, } return nil, fmt.Errorf("cannot convert hash to bytes for chain %d", chainID) } - -// ParseAddressAndData parses the message string into an address and data -// message is hex encoded byte array -// [ contractAddress calldata ] -// [ 20B, variable] -func ParseAddressAndData(message string) (ethcommon.Address, []byte, error) { - if len(message) == 0 { - return ethcommon.Address{}, nil, nil - } - - data, err := hex.DecodeString(message) - if err != nil { - return ethcommon.Address{}, nil, errors.Wrap(err, "message should be a hex encoded string") - } - - if len(data) < 20 { - return ethcommon.Address{}, data, nil - } - - address := ethcommon.BytesToAddress(data[:20]) - data = data[20:] - return address, data, nil -} diff --git a/pkg/chains/utils_test.go b/pkg/chains/utils_test.go index 915fa2b8ca..7552e10967 100644 --- a/pkg/chains/utils_test.go +++ b/pkg/chains/utils_test.go @@ -1,7 +1,6 @@ package chains import ( - "encoding/hex" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -66,39 +65,3 @@ func TestStringToHash(t *testing.T) { }) } } - -func TestParseAddressAndData(t *testing.T) { - expectedShortMsgResult, err := hex.DecodeString("1a2b3c4d5e6f708192a3b4c5d6e7f808") - require.NoError(t, err) - tests := []struct { - name string - message string - expectAddr ethcommon.Address - expectData []byte - wantErr bool - }{ - { - "valid msg", - "95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5", - ethcommon.HexToAddress("95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5"), - []byte{}, - false, - }, - {"empty msg", "", ethcommon.Address{}, nil, false}, - {"invalid hex", "invalidHex", ethcommon.Address{}, nil, true}, - {"short msg", "1a2b3c4d5e6f708192a3b4c5d6e7f808", ethcommon.Address{}, expectedShortMsgResult, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - addr, data, err := ParseAddressAndData(tt.message) - if tt.wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expectAddr, addr) - require.Equal(t, tt.expectData, data) - } - }) - } -} diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go index 420bcac6fa..30670952b0 100644 --- a/pkg/memo/memo.go +++ b/pkg/memo/memo.go @@ -1,8 +1,10 @@ package memo import ( + "encoding/hex" "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" ) @@ -72,3 +74,25 @@ func DecodeFromBytes(data []byte) (*InboundMemo, error) { return memo, nil } + +// DecodeLegacyMemoHex decodes hex encoded memo message into address and calldata +// +// The layout of legacy memo is: [20-byte address, variable calldata] +func DecodeLegacyMemoHex(message string) (common.Address, []byte, error) { + if len(message) == 0 { + return common.Address{}, nil, nil + } + + data, err := hex.DecodeString(message) + if err != nil { + return common.Address{}, nil, errors.Wrap(err, "message should be a hex encoded string") + } + + if len(data) < common.AddressLength { + return common.Address{}, data, nil + } + + address := common.BytesToAddress(data[:common.AddressLength]) + data = data[common.AddressLength:] + return address, data, nil +} diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index 1de361e267..50e240933e 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -1,8 +1,10 @@ package memo_test import ( + "encoding/hex" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/testutil/sample" @@ -153,8 +155,6 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { name string head []byte data []byte - headHex string - dataHex string expectedMemo memo.InboundMemo errMsg string }{ @@ -269,3 +269,39 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }) } } + +func Test_DecodeLegacyMemoHex(t *testing.T) { + expectedShortMsgResult, err := hex.DecodeString("1a2b3c4d5e6f708192a3b4c5d6e7f808") + require.NoError(t, err) + tests := []struct { + name string + message string + expectAddr common.Address + expectData []byte + wantErr bool + }{ + { + "valid msg", + "95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5", + common.HexToAddress("95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5"), + []byte{}, + false, + }, + {"empty msg", "", common.Address{}, nil, false}, + {"invalid hex", "invalidHex", common.Address{}, nil, true}, + {"short msg", "1a2b3c4d5e6f708192a3b4c5d6e7f808", common.Address{}, expectedShortMsgResult, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, data, err := memo.DecodeLegacyMemoHex(tt.message) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectAddr, addr) + require.Equal(t, tt.expectData, data) + } + }) + } +} diff --git a/x/crosschain/keeper/evm_deposit.go b/x/crosschain/keeper/evm_deposit.go index c672686c1a..be08873bba 100644 --- a/x/crosschain/keeper/evm_deposit.go +++ b/x/crosschain/keeper/evm_deposit.go @@ -14,6 +14,7 @@ import ( "github.com/zeta-chain/node/pkg/chains" "github.com/zeta-chain/node/pkg/coin" + "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/x/crosschain/types" fungibletypes "github.com/zeta-chain/node/x/fungible/types" ) @@ -78,7 +79,7 @@ func (k Keeper) HandleEVMDeposit(ctx sdk.Context, cctx *types.CrossChainTx) (boo // in protocol version 2, the destination of the deposit is always the to address, the message is the data to be sent to the contract if cctx.ProtocolContractVersion == types.ProtocolContractVersion_V1 { var parsedAddress ethcommon.Address - parsedAddress, message, err = chains.ParseAddressAndData(cctx.RelayedMessage) + parsedAddress, message, err = memo.DecodeLegacyMemoHex(cctx.RelayedMessage) if err != nil { return false, errors.Wrap(types.ErrUnableToParseAddress, err.Error()) } diff --git a/zetaclient/chains/bitcoin/observer/inbound.go b/zetaclient/chains/bitcoin/observer/inbound.go index fad3c87230..1461096763 100644 --- a/zetaclient/chains/bitcoin/observer/inbound.go +++ b/zetaclient/chains/bitcoin/observer/inbound.go @@ -14,8 +14,8 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" - "github.com/zeta-chain/node/pkg/chains" "github.com/zeta-chain/node/pkg/coin" + "github.com/zeta-chain/node/pkg/memo" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" "github.com/zeta-chain/node/zetaclient/chains/bitcoin" "github.com/zeta-chain/node/zetaclient/chains/interfaces" @@ -419,7 +419,7 @@ func (ob *Observer) GetInboundVoteMessageFromBtcEvent(inbound *BTCInboundEvent) // TODO(revamp): move all compliance related functions in a specific file func (ob *Observer) DoesInboundContainsRestrictedAddress(inTx *BTCInboundEvent) bool { receiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(hex.EncodeToString(inTx.MemoBytes)) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(hex.EncodeToString(inTx.MemoBytes)) if err == nil && parsedAddress != (ethcommon.Address{}) { receiver = parsedAddress.Hex() } diff --git a/zetaclient/chains/evm/observer/inbound.go b/zetaclient/chains/evm/observer/inbound.go index c662fc2d60..aebd661ae3 100644 --- a/zetaclient/chains/evm/observer/inbound.go +++ b/zetaclient/chains/evm/observer/inbound.go @@ -20,9 +20,9 @@ import ( "github.com/zeta-chain/protocol-contracts/v1/pkg/contracts/evm/erc20custody.sol" "github.com/zeta-chain/protocol-contracts/v1/pkg/contracts/evm/zetaconnector.non-eth.sol" - "github.com/zeta-chain/node/pkg/chains" "github.com/zeta-chain/node/pkg/coin" "github.com/zeta-chain/node/pkg/constant" + "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/pkg/ticker" "github.com/zeta-chain/node/x/crosschain/types" "github.com/zeta-chain/node/zetaclient/chains/evm" @@ -622,7 +622,7 @@ func (ob *Observer) BuildInboundVoteMsgForDepositedEvent( ) *types.MsgVoteInbound { // compliance check maybeReceiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(hex.EncodeToString(event.Message)) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(hex.EncodeToString(event.Message)) if err == nil && parsedAddress != (ethcommon.Address{}) { maybeReceiver = parsedAddress.Hex() } @@ -732,7 +732,7 @@ func (ob *Observer) BuildInboundVoteMsgForTokenSentToTSS( // compliance check maybeReceiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(message) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(message) if err == nil && parsedAddress != (ethcommon.Address{}) { maybeReceiver = parsedAddress.Hex() } diff --git a/zetaclient/compliance/compliance.go b/zetaclient/compliance/compliance.go index 163cca65e4..f0135c3ad9 100644 --- a/zetaclient/compliance/compliance.go +++ b/zetaclient/compliance/compliance.go @@ -7,7 +7,7 @@ import ( ethcommon "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" - "github.com/zeta-chain/node/pkg/chains" + "github.com/zeta-chain/node/pkg/memo" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" "github.com/zeta-chain/node/zetaclient/chains/base" "github.com/zeta-chain/node/zetaclient/config" @@ -66,7 +66,7 @@ func PrintComplianceLog( func DoesInboundContainsRestrictedAddress(event *clienttypes.InboundEvent, logger *base.ObserverLogger) bool { // parse memo-specified receiver receiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(hex.EncodeToString(event.Memo)) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(hex.EncodeToString(event.Memo)) if err == nil && parsedAddress != (ethcommon.Address{}) { receiver = parsedAddress.Hex() } From 4f5b1a1f1bc078359bfb411a73ad948f999e3cc5 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Tue, 15 Oct 2024 18:01:10 -0500 Subject: [PATCH 18/19] move sample functions ABIPack, CompactPack into memo pkg self; remove sample package reference from memo to minimize dependency --- pkg/memo/arg_test.go | 8 +- pkg/memo/codec_abi_test.go | 131 ++++++++++++++++++++++-- pkg/memo/codec_compact_test.go | 77 +++++++++++--- pkg/memo/fields_v0_test.go | 22 ++-- pkg/memo/header_test.go | 16 ++- pkg/memo/memo_test.go | 33 +++--- testutil/sample/memo.go | 181 --------------------------------- 7 files changed, 228 insertions(+), 240 deletions(-) delete mode 100644 testutil/sample/memo.go diff --git a/pkg/memo/arg_test.go b/pkg/memo/arg_test.go index 5fa9e6a87c..ee9373e688 100644 --- a/pkg/memo/arg_test.go +++ b/pkg/memo/arg_test.go @@ -3,15 +3,15 @@ package memo_test import ( "testing" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" - "github.com/zeta-chain/node/testutil/sample" ) func Test_NewArg(t *testing.T) { - argAddress := sample.EthAddress() - argString := sample.String() - argBytes := sample.Bytes() + argAddress := common.HexToAddress("0x0B85C56e5453e0f4273d1D1BF3091d43B08B38CE") + argString := "some other string argument" + argBytes := []byte("here is a bytes argument") tests := []struct { name string diff --git a/pkg/memo/codec_abi_test.go b/pkg/memo/codec_abi_test.go index b82608d929..c4a2e1decd 100644 --- a/pkg/memo/codec_abi_test.go +++ b/pkg/memo/codec_abi_test.go @@ -2,14 +2,127 @@ package memo_test import ( "bytes" + "encoding/binary" "testing" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" - "github.com/zeta-chain/node/testutil/sample" ) +const ( + // abiAlignment is the number of bytes used to align the ABI encoded data + abiAlignment = 32 +) + +// ABIPack is a helper function that simulates the abi.Pack function. +// Note: all arguments are assumed to be <= 32 bytes for simplicity. +func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { + packedData := make([]byte, 0) + + // data offset for 1st dynamic-length field + offset := abiAlignment * len(args) + + // 1. pack 32-byte offset for each dynamic-length field (bytes, string) + // 2. pack actual data for each fixed-length field (address) + for _, arg := range args { + switch arg.Type { + case memo.ArgTypeBytes: + // left-pad length as uint16 + buff := make([]byte, 2) + binary.BigEndian.PutUint16(buff, uint16(offset)) + offsetData := abiPad32(t, buff, true) + packedData = append(packedData, offsetData...) + + argLen := len(arg.Arg.([]byte)) + if argLen > 0 { + offset += abiAlignment * 2 // [length + data] + } else { + offset += abiAlignment // only [length] + } + + case memo.ArgTypeString: + // left-pad length as uint16 + buff := make([]byte, 2) + binary.BigEndian.PutUint16(buff, uint16(offset)) + offsetData := abiPad32(t, buff, true) + packedData = append(packedData, offsetData...) + + argLen := len([]byte(arg.Arg.(string))) + if argLen > 0 { + offset += abiAlignment * 2 // [length + data] + } else { + offset += abiAlignment // only [length] + } + + case memo.ArgTypeAddress: // left-pad for address + data := abiPad32(t, arg.Arg.(common.Address).Bytes(), true) + packedData = append(packedData, data...) + } + } + + // pack dynamic-length fields + dynamicData := abiPackDynamicData(t, args...) + packedData = append(packedData, dynamicData...) + + return packedData +} + +// abiPad32 is a helper function to pad a byte slice to 32 bytes +func abiPad32(t *testing.T, data []byte, left bool) []byte { + // nothing needs to be encoded, return empty bytes + if len(data) == 0 { + return []byte{} + } + + require.LessOrEqual(t, len(data), abiAlignment) + padded := make([]byte, 32) + + if left { + // left-pad the data for fixed-size types + copy(padded[32-len(data):], data) + } else { + // right-pad the data for dynamic types + copy(padded, data) + } + return padded +} + +// abiPackDynamicData is a helper function to pack dynamic-length data +func abiPackDynamicData(t *testing.T, args ...memo.CodecArg) []byte { + packedData := make([]byte, 0) + + // pack with ABI format: length + data + for _, arg := range args { + // get length + var length int + switch arg.Type { + case memo.ArgTypeBytes: + length = len(arg.Arg.([]byte)) + case memo.ArgTypeString: + length = len([]byte(arg.Arg.(string))) + default: + continue + } + + // append length in bytes + lengthData := abiPad32(t, []byte{byte(length)}, true) + packedData = append(packedData, lengthData...) + + // append actual data in bytes + switch arg.Type { + case memo.ArgTypeBytes: // right-pad for bytes + data := abiPad32(t, arg.Arg.([]byte), false) + packedData = append(packedData, data...) + case memo.ArgTypeString: // right-pad for string + data := abiPad32(t, []byte(arg.Arg.(string)), false) + packedData = append(packedData, data...) + } + } + + return packedData +} + // newArgInstance creates a new instance of the given argument type func newArgInstance(v interface{}) interface{} { switch v.(type) { @@ -46,7 +159,7 @@ func Test_CodecABI_AddArguments(t *testing.T) { codec := memo.NewCodecABI() require.NotNil(t, codec) - address := sample.EthAddress() + address := common.HexToAddress("0xEf221eC80f004E6A2ee4E5F5d800699c1C68cD6F") codec.AddArguments(memo.ArgReceiver(&address)) // attempt to pack the arguments, result should not be nil @@ -57,7 +170,7 @@ func Test_CodecABI_AddArguments(t *testing.T) { func Test_CodecABI_PackArgument(t *testing.T) { // create sample arguments - argAddress := sample.EthAddress() + argAddress := common.HexToAddress("0xEf221eC80f004E6A2ee4E5F5d800699c1C68cD6F") argBytes := []byte("some test bytes argument") argString := "some test string argument" @@ -124,7 +237,7 @@ func Test_CodecABI_PackArgument(t *testing.T) { require.NoError(t, err) // calc expected data for comparison - expectedData := sample.ABIPack(t, tc.args...) + expectedData := ABIPack(t, tc.args...) // validate the packed data require.True(t, bytes.Equal(expectedData, packedData), "ABI encoded data mismatch") @@ -134,7 +247,7 @@ func Test_CodecABI_PackArgument(t *testing.T) { func Test_CodecABI_UnpackArguments(t *testing.T) { // create sample arguments - argAddress := sample.EthAddress() + argAddress := common.HexToAddress("0xEf221eC80f004E6A2ee4E5F5d800699c1C68cD6F") argBytes := []byte("some test bytes argument") argString := "some test string argument" @@ -147,7 +260,7 @@ func Test_CodecABI_UnpackArguments(t *testing.T) { }{ { name: "unpack in the order of [address, bytes, string]", - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes), memo.ArgRevertAddress(argString)), @@ -159,7 +272,7 @@ func Test_CodecABI_UnpackArguments(t *testing.T) { }, { name: "unpack in the order of [string, address, bytes]", - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgRevertAddress(argString), memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes)), @@ -171,7 +284,7 @@ func Test_CodecABI_UnpackArguments(t *testing.T) { }, { name: "unpack empty bytes array and string", - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgPayload([]byte{}), memo.ArgRevertAddress("")), expected: []memo.CodecArg{ @@ -190,7 +303,7 @@ func Test_CodecABI_UnpackArguments(t *testing.T) { }, { name: "unpacking should fail on argument type mismatch", - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgReceiver(argAddress), ), expected: []memo.CodecArg{ diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go index 9cec7d122d..328beeb7e4 100644 --- a/pkg/memo/codec_compact_test.go +++ b/pkg/memo/codec_compact_test.go @@ -2,14 +2,61 @@ package memo_test import ( "bytes" + "encoding/binary" + "strings" "testing" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" - "github.com/zeta-chain/node/testutil/sample" ) +// CompactPack is a helper function to pack arguments into compact encoded data +// Note: all arguments are assumed to be <= 65535 bytes for simplicity. +func CompactPack(encodingFmt memo.EncodingFormat, args ...memo.CodecArg) []byte { + var ( + length int + packedData []byte + ) + + for _, arg := range args { + // get length of argument + switch arg.Type { + case memo.ArgTypeBytes: + length = len(arg.Arg.([]byte)) + case memo.ArgTypeString: + length = len([]byte(arg.Arg.(string))) + default: + // skip length for other types + length = -1 + } + + // append length in bytes + if length != -1 { + switch encodingFmt { + case memo.EncodingFmtCompactShort: + packedData = append(packedData, byte(length)) + case memo.EncodingFmtCompactLong: + buff := make([]byte, 2) + binary.LittleEndian.PutUint16(buff, uint16(length)) + packedData = append(packedData, buff...) + } + } + + // append actual data in bytes + switch arg.Type { + case memo.ArgTypeBytes: + packedData = append(packedData, arg.Arg.([]byte)...) + case memo.ArgTypeAddress: + packedData = append(packedData, arg.Arg.(common.Address).Bytes()...) + case memo.ArgTypeString: + packedData = append(packedData, []byte(arg.Arg.(string))...) + } + } + + return packedData +} + func Test_NewCodecCompact(t *testing.T) { tests := []struct { name string @@ -46,7 +93,7 @@ func Test_CodecCompact_AddArguments(t *testing.T) { require.NoError(t, err) require.NotNil(t, codec) - address := sample.EthAddress() + address := common.HexToAddress("0x855EfD3C54F9Ed106C6c3FB343539c89Df042e0B") codec.AddArguments(memo.ArgReceiver(address)) // attempt to pack the arguments, result should not be nil @@ -57,7 +104,7 @@ func Test_CodecCompact_AddArguments(t *testing.T) { func Test_CodecCompact_PackArguments(t *testing.T) { // create sample arguments - argAddress := sample.EthAddress() + argAddress := common.HexToAddress("0x855EfD3C54F9Ed106C6c3FB343539c89Df042e0B") argBytes := []byte("here is a bytes argument") argString := "some other string argument" @@ -93,7 +140,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { name: "pack long string (> 255 bytes) with compact-long format", encodeFmt: memo.EncodingFmtCompactLong, args: []memo.CodecArg{ - memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 256))), + memo.ArgPayload([]byte(strings.Repeat("a", 256))), }, expectedLen: 2 + 256, }, @@ -101,7 +148,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { name: "pack long string (> 255 bytes) with compact-short format should fail", encodeFmt: memo.EncodingFmtCompactShort, args: []memo.CodecArg{ - memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 256))), + memo.ArgPayload([]byte(strings.Repeat("b", 256))), }, errMsg: "exceeds 255 bytes", }, @@ -109,7 +156,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { name: "pack long string (> 65535 bytes) with compact-long format should fail", encodeFmt: memo.EncodingFmtCompactLong, args: []memo.CodecArg{ - memo.ArgPayload([]byte(sample.StringRandom(sample.Rand(), 65536))), + memo.ArgPayload([]byte(strings.Repeat("c", 65536))), }, errMsg: "exceeds 65535 bytes", }, @@ -176,7 +223,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { require.Equal(t, tc.expectedLen, len(packedData)) // calc expected data for comparison - expectedData := sample.CompactPack(tc.encodeFmt, tc.args...) + expectedData := CompactPack(tc.encodeFmt, tc.args...) // validate the packed data require.True(t, bytes.Equal(expectedData, packedData), "compact encoded data mismatch") @@ -186,7 +233,7 @@ func Test_CodecCompact_PackArguments(t *testing.T) { func Test_CodecCompact_UnpackArguments(t *testing.T) { // create sample arguments - argAddress := sample.EthAddress() + argAddress := common.HexToAddress("0x855EfD3C54F9Ed106C6c3FB343539c89Df042e0B") argBytes := []byte("some test bytes argument") argString := "some other string argument" @@ -201,7 +248,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "unpack arguments of [address, bytes, string] in compact-short format", encodeFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes), @@ -216,7 +263,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "unpack arguments of [string, address, bytes] in compact-long format", encodeFmt: memo.EncodingFmtCompactLong, - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactLong, memo.ArgRevertAddress(argString), memo.ArgReceiver(argAddress), @@ -231,7 +278,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "unpack empty byte array and string argument", encodeFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload([]byte{}), memo.ArgRevertAddress(""), @@ -271,7 +318,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "failed to unpack bytes argument if string is passed", encodeFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload(argBytes), ), @@ -283,7 +330,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "failed to unpack address argument if bytes is passed", encodeFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), ), @@ -295,7 +342,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { { name: "failed to unpack string argument if address is passed", encodeFmt: memo.EncodingFmtCompactShort, - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgRevertAddress(argString), ), @@ -317,7 +364,7 @@ func Test_CodecCompact_UnpackArguments(t *testing.T) { name: "unpacking should fail if not all data is consumed", encodeFmt: memo.EncodingFmtCompactShort, data: func() []byte { - data := sample.CompactPack( + data := CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(argAddress), memo.ArgPayload(argBytes), diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go index 1339f4370a..13742422c2 100644 --- a/pkg/memo/fields_v0_test.go +++ b/pkg/memo/fields_v0_test.go @@ -47,7 +47,7 @@ func Test_V0_Pack(t *testing.T) { RevertMessage: fBytes, }, }, - expectedData: sample.ABIPack(t, + expectedData: ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), memo.ArgRevertAddress(fString), @@ -69,7 +69,7 @@ func Test_V0_Pack(t *testing.T) { RevertMessage: fBytes, }, }, - expectedData: sample.CompactPack( + expectedData: CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -119,7 +119,7 @@ func Test_V0_Pack(t *testing.T) { func Test_V0_Unpack(t *testing.T) { // create sample fields - fAddress := sample.EthAddress() + fAddress := common.HexToAddress("0xA029D053E13223E2442E28be80b3CeDA27ecbE31") fBytes := []byte("here_s_some_bytes_field") fString := "this_is_a_string_field" @@ -137,7 +137,7 @@ func Test_V0_Unpack(t *testing.T) { opCode: memo.OpCodeDepositAndCall, encodeFmt: memo.EncodingFmtABI, dataFlags: flagsAllFieldsSet, // all fields are set - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), memo.ArgRevertAddress(fString), @@ -159,7 +159,7 @@ func Test_V0_Unpack(t *testing.T) { opCode: memo.OpCodeDepositAndCall, encodeFmt: memo.EncodingFmtCompactShort, dataFlags: flagsAllFieldsSet, // all fields are set - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -182,7 +182,7 @@ func Test_V0_Unpack(t *testing.T) { opCode: memo.OpCodeDepositAndCall, encodeFmt: memo.EncodingFmtABI, dataFlags: 0b00000010, // payload flags are set - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgPayload([]byte{})), // empty payload expected: memo.FieldsV0{}, }, @@ -191,7 +191,7 @@ func Test_V0_Unpack(t *testing.T) { opCode: memo.OpCodeDepositAndCall, encodeFmt: memo.EncodingFmtCompactShort, dataFlags: 0b00000010, // payload flag is set - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgPayload([]byte{})), // empty payload expected: memo.FieldsV0{}, @@ -209,7 +209,7 @@ func Test_V0_Unpack(t *testing.T) { opCode: memo.OpCodeDepositAndCall, encodeFmt: memo.EncodingFmtCompactShort, dataFlags: 0b00000011, // receiver and payload flags are set - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes)), errMsg: "failed to unpack arguments", @@ -219,7 +219,7 @@ func Test_V0_Unpack(t *testing.T) { opCode: memo.OpCodeDepositAndCall, encodeFmt: memo.EncodingFmtABI, dataFlags: 0b00000011, // receiver and payload flags are set - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgReceiver(common.Address{}), memo.ArgPayload(fBytes)), errMsg: "receiver address is empty", @@ -247,7 +247,7 @@ func Test_V0_Unpack(t *testing.T) { func Test_V0_Validate(t *testing.T) { // create sample fields - fAddress := sample.EthAddress() + fAddress := common.HexToAddress("0xA029D053E13223E2442E28be80b3CeDA27ecbE31") fBytes := []byte("here_s_some_bytes_field") fString := "this_is_a_string_field" @@ -334,7 +334,7 @@ func Test_V0_Validate(t *testing.T) { func Test_V0_DataFlags(t *testing.T) { // create sample fields - fAddress := sample.EthAddress() + fAddress := common.HexToAddress("0xA029D053E13223E2442E28be80b3CeDA27ecbE31") fBytes := []byte("here_s_some_bytes_field") fString := "this_is_a_string_field" diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go index 31693a7c18..9d89e1f1d6 100644 --- a/pkg/memo/header_test.go +++ b/pkg/memo/header_test.go @@ -5,9 +5,19 @@ import ( "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" - "github.com/zeta-chain/node/testutil/sample" ) +// MakeHead is a helper function to create a memo head +// Note: all arguments are assumed to be <= 0b1111 for simplicity. +func MakeHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { + head := make([]byte, memo.HeaderSize) + head[0] = memo.Identifier + head[1] = version<<4 | encodingFmt + head[2] = opCode<<4 | reserved + head[3] = flags + return head +} + func Test_Header_EncodeToBytes(t *testing.T) { tests := []struct { name string @@ -58,7 +68,7 @@ func Test_Header_DecodeFromBytes(t *testing.T) { { name: "it works", data: append( - sample.MemoHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeCall), 0, 0), + MakeHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeCall), 0, 0), []byte{0x01, 0x02}...), expected: memo.Header{ Version: 0, @@ -80,7 +90,7 @@ func Test_Header_DecodeFromBytes(t *testing.T) { { name: "header validation failed", data: append( - sample.MemoHead(0, uint8(memo.EncodingFmtInvalid), uint8(memo.OpCodeCall), 0, 0), + MakeHead(0, uint8(memo.EncodingFmtInvalid), uint8(memo.OpCodeCall), 0, 0), []byte{0x01, 0x02}...), // invalid encoding format errMsg: "invalid encoding format", }, diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go index 50e240933e..e6cb067793 100644 --- a/pkg/memo/memo_test.go +++ b/pkg/memo/memo_test.go @@ -7,13 +7,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" "github.com/zeta-chain/node/pkg/memo" - "github.com/zeta-chain/node/testutil/sample" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" ) func Test_Memo_EncodeToBytes(t *testing.T) { // create sample fields - fAddress := sample.EthAddress() + fAddress := common.HexToAddress("0xEA9808f0Ac504d1F521B5BbdfC33e6f1953757a7") fBytes := []byte("here_s_some_bytes_field") fString := "this_is_a_string_field" @@ -43,14 +42,14 @@ func Test_Memo_EncodeToBytes(t *testing.T) { }, }, }, - expectedHead: sample.MemoHead( + expectedHead: MakeHead( 0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), - expectedData: sample.ABIPack(t, + expectedData: ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), memo.ArgRevertAddress(fString), @@ -76,14 +75,14 @@ func Test_Memo_EncodeToBytes(t *testing.T) { }, }, }, - expectedHead: sample.MemoHead( + expectedHead: MakeHead( 0, uint8(memo.EncodingFmtCompactShort), uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), - expectedData: sample.CompactPack( + expectedData: CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -147,7 +146,7 @@ func Test_Memo_EncodeToBytes(t *testing.T) { func Test_Memo_DecodeFromBytes(t *testing.T) { // create sample fields - fAddress := sample.EthAddress() + fAddress := common.HexToAddress("0xEA9808f0Ac504d1F521B5BbdfC33e6f1953757a7") fBytes := []byte("here_s_some_bytes_field") fString := "this_is_a_string_field" @@ -160,14 +159,14 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }{ { name: "decode memo with ABI encoding", - head: sample.MemoHead( + head: MakeHead( 0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), - data: sample.ABIPack(t, + data: ABIPack(t, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), memo.ArgRevertAddress(fString), @@ -194,14 +193,14 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }, { name: "decode memo with compact encoding", - head: sample.MemoHead( + head: MakeHead( 0, uint8(memo.EncodingFmtCompactLong), uint8(memo.OpCodeDepositAndCall), 0, flagsAllFieldsSet, // all fields are set ), - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactLong, memo.ArgReceiver(fAddress), memo.ArgPayload(fBytes), @@ -229,26 +228,26 @@ func Test_Memo_DecodeFromBytes(t *testing.T) { }, { name: "failed to decode memo header", - head: sample.MemoHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeInvalid), 0, 0), - data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), + head: MakeHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeInvalid), 0, 0), + data: ABIPack(t, memo.ArgReceiver(fAddress)), errMsg: "failed to decode memo header", }, { name: "failed to decode if version is invalid", - head: sample.MemoHead(1, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDeposit), 0, 0), - data: sample.ABIPack(t, memo.ArgReceiver(fAddress)), + head: MakeHead(1, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDeposit), 0, 0), + data: ABIPack(t, memo.ArgReceiver(fAddress)), errMsg: "invalid memo version", }, { name: "failed to decode compact encoded data with ABI encoding format", - head: sample.MemoHead( + head: MakeHead( 0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDepositAndCall), 0, 0, ), // header says ABI encoding - data: sample.CompactPack( + data: CompactPack( memo.EncodingFmtCompactShort, memo.ArgReceiver(fAddress), ), // but data is compact encoded diff --git a/testutil/sample/memo.go b/testutil/sample/memo.go deleted file mode 100644 index a70ce3537c..0000000000 --- a/testutil/sample/memo.go +++ /dev/null @@ -1,181 +0,0 @@ -package sample - -import ( - "encoding/binary" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/stretchr/testify/require" - - "github.com/zeta-chain/node/pkg/memo" -) - -const ( - // abiAlignment is the number of bytes used to align the ABI encoded data - abiAlignment = 32 -) - -// MemoHead is a helper function to create a memo head -// Note: all arguments are assumed to be <= 0b1111 for simplicity. -func MemoHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { - head := make([]byte, memo.HeaderSize) - head[0] = memo.Identifier - head[1] = version<<4 | encodingFmt - head[2] = opCode<<4 | reserved - head[3] = flags - return head -} - -// ABIPack is a helper function that simulates the abi.Pack function. -// Note: all arguments are assumed to be <= 32 bytes for simplicity. -func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { - packedData := make([]byte, 0) - - // data offset for 1st dynamic-length field - offset := abiAlignment * len(args) - - // 1. pack 32-byte offset for each dynamic-length field (bytes, string) - // 2. pack actual data for each fixed-length field (address) - for _, arg := range args { - switch arg.Type { - case memo.ArgTypeBytes: - // left-pad length as uint16 - buff := make([]byte, 2) - binary.BigEndian.PutUint16(buff, uint16(offset)) - offsetData := abiPad32(t, buff, true) - packedData = append(packedData, offsetData...) - - argLen := len(arg.Arg.([]byte)) - if argLen > 0 { - offset += abiAlignment * 2 // [length + data] - } else { - offset += abiAlignment // only [length] - } - - case memo.ArgTypeString: - // left-pad length as uint16 - buff := make([]byte, 2) - binary.BigEndian.PutUint16(buff, uint16(offset)) - offsetData := abiPad32(t, buff, true) - packedData = append(packedData, offsetData...) - - argLen := len([]byte(arg.Arg.(string))) - if argLen > 0 { - offset += abiAlignment * 2 // [length + data] - } else { - offset += abiAlignment // only [length] - } - - case memo.ArgTypeAddress: // left-pad for address - data := abiPad32(t, arg.Arg.(common.Address).Bytes(), true) - packedData = append(packedData, data...) - } - } - - // pack dynamic-length fields - dynamicData := abiPackDynamicData(t, args...) - packedData = append(packedData, dynamicData...) - - return packedData -} - -// CompactPack is a helper function to pack arguments into compact encoded data -// Note: all arguments are assumed to be <= 65535 bytes for simplicity. -func CompactPack(encodingFmt memo.EncodingFormat, args ...memo.CodecArg) []byte { - var ( - length int - packedData []byte - ) - - for _, arg := range args { - // get length of argument - switch arg.Type { - case memo.ArgTypeBytes: - length = len(arg.Arg.([]byte)) - case memo.ArgTypeString: - length = len([]byte(arg.Arg.(string))) - default: - // skip length for other types - length = -1 - } - - // append length in bytes - if length != -1 { - switch encodingFmt { - case memo.EncodingFmtCompactShort: - packedData = append(packedData, byte(length)) - case memo.EncodingFmtCompactLong: - buff := make([]byte, 2) - binary.LittleEndian.PutUint16(buff, uint16(length)) - packedData = append(packedData, buff...) - } - } - - // append actual data in bytes - switch arg.Type { - case memo.ArgTypeBytes: - packedData = append(packedData, arg.Arg.([]byte)...) - case memo.ArgTypeAddress: - packedData = append(packedData, arg.Arg.(common.Address).Bytes()...) - case memo.ArgTypeString: - packedData = append(packedData, []byte(arg.Arg.(string))...) - } - } - - return packedData -} - -// abiPad32 is a helper function to pad a byte slice to 32 bytes -func abiPad32(t *testing.T, data []byte, left bool) []byte { - // nothing needs to be encoded, return empty bytes - if len(data) == 0 { - return []byte{} - } - - require.LessOrEqual(t, len(data), abiAlignment) - padded := make([]byte, 32) - - if left { - // left-pad the data for fixed-size types - copy(padded[32-len(data):], data) - } else { - // right-pad the data for dynamic types - copy(padded, data) - } - return padded -} - -// abiPackDynamicData is a helper function to pack dynamic-length data -func abiPackDynamicData(t *testing.T, args ...memo.CodecArg) []byte { - packedData := make([]byte, 0) - - // pack with ABI format: length + data - for _, arg := range args { - // get length - var length int - switch arg.Type { - case memo.ArgTypeBytes: - length = len(arg.Arg.([]byte)) - case memo.ArgTypeString: - length = len([]byte(arg.Arg.(string))) - default: - continue - } - - // append length in bytes - lengthData := abiPad32(t, []byte{byte(length)}, true) - packedData = append(packedData, lengthData...) - - // append actual data in bytes - switch arg.Type { - case memo.ArgTypeBytes: // right-pad for bytes - data := abiPad32(t, arg.Arg.([]byte), false) - packedData = append(packedData, data...) - case memo.ArgTypeString: // right-pad for string - data := abiPad32(t, []byte(arg.Arg.(string)), false) - packedData = append(packedData, data...) - } - } - - return packedData -} From b3426af1cd6ad1cdea054c605ca7d069856478ee Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Tue, 15 Oct 2024 22:22:42 -0500 Subject: [PATCH 19/19] fix unit test compile error --- pkg/math/bits/bits_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/math/bits/bits_test.go b/pkg/math/bits/bits_test.go index 445555a9a2..65632d0a24 100644 --- a/pkg/math/bits/bits_test.go +++ b/pkg/math/bits/bits_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" - "github.com/zeta-chain/node/pkg/math" + zetabits "github.com/zeta-chain/node/pkg/math/bits" ) func TestSetBit(t *testing.T) { @@ -37,7 +37,7 @@ func TestSetBit(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b := tt.initial - math.SetBit(&b, tt.position) + zetabits.SetBit(&b, tt.position) require.Equal(t, tt.expected, b) }) } @@ -78,7 +78,7 @@ func TestIsBitSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := math.IsBitSet(tt.b, tt.position) + result := zetabits.IsBitSet(tt.b, tt.position) require.Equal(t, tt.expected, result) }) } @@ -125,7 +125,7 @@ func TestGetBits(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := math.GetBits(tt.b, tt.mask) + result := zetabits.GetBits(tt.b, tt.mask) require.Equal(t, tt.expected, result) }) } @@ -164,7 +164,7 @@ func TestSetBits(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := math.SetBits(tt.b, tt.mask, tt.value) + result := zetabits.SetBits(tt.b, tt.mask, tt.value) require.Equal(t, tt.expected, result) }) }