diff --git a/pkg/codec/by_item_type_modifier.go b/pkg/codec/by_item_type_modifier.go index f68e73e9d..af0c24a21 100644 --- a/pkg/codec/by_item_type_modifier.go +++ b/pkg/codec/by_item_type_modifier.go @@ -14,39 +14,75 @@ func NewByItemTypeModifier(modByItemType map[string]Modifier) (Modifier, error) } return &byItemTypeModifier{ - modByitemType: modByItemType, + modByItemType: modByItemType, + enableNesting: false, + }, nil +} + +// NewNestableByItemTypeModifier returns a Modifier that uses modByItemType to determine which Modifier to use for a +// given itemType. If itemType is structured as a dot-separated string like 'A.B.C', the first part 'A' will be used to +// match in the mod map and the remaining list will be provided to the found Modifier 'B.C'. +func NewNestableByItemTypeModifier(modByItemType map[string]Modifier) (Modifier, error) { + if modByItemType == nil { + modByItemType = map[string]Modifier{} + } + + return &byItemTypeModifier{ + modByItemType: modByItemType, + enableNesting: true, }, nil } type byItemTypeModifier struct { - modByitemType map[string]Modifier + modByItemType map[string]Modifier + enableNesting bool } -func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { - mod, ok := b.modByitemType[itemType] +// RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this +// function returns an error if a modifier by the specified name is not found. If nesting is enabled, the itemType can +// be of the form `Path.To.Type` and this modifier will attempt to only match on `Path` to find a valid modifier. +func (m *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) { + head := itemType + tail := itemType + + if m.enableNesting { + head, tail = ItemTyper(itemType).Next() + } + + mod, ok := m.modByItemType[head] if !ok { return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) } - return mod.RetypeToOffChain(onChainType, itemType) + return mod.RetypeToOffChain(onChainType, tail) } -func (b *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { - return b.transform(offChainValue, itemType, Modifier.TransformToOnChain) +func (m *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { + return m.transform(offChainValue, itemType, Modifier.TransformToOnChain) } -func (b *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { - return b.transform(onChainValue, itemType, Modifier.TransformToOffChain) +func (m *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { + return m.transform(onChainValue, itemType, Modifier.TransformToOffChain) } -func (b *byItemTypeModifier) transform( - val any, itemType string, transform func(Modifier, any, string) (any, error)) (any, error) { - mod, ok := b.modByitemType[itemType] +func (m *byItemTypeModifier) transform( + val any, + itemType string, + transform func(Modifier, any, string) (any, error), +) (any, error) { + head := itemType + tail := itemType + + if m.enableNesting { + head, tail = ItemTyper(itemType).Next() + } + + mod, ok := m.modByItemType[head] if !ok { return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType) } - return transform(mod, val, itemType) + return transform(mod, val, tail) } var _ Modifier = &byItemTypeModifier{} diff --git a/pkg/codec/byte_string_modifier.go b/pkg/codec/byte_string_modifier.go index 153cc6e20..b6b6f6524 100644 --- a/pkg/codec/byte_string_modifier.go +++ b/pkg/codec/byte_string_modifier.go @@ -25,7 +25,18 @@ type AddressModifier interface { // // The fields parameter specifies which fields within a struct should be modified. The AddressModifier // is injected into the modifier to handle chain-specific logic during the contractReader relayer configuration. -func NewAddressBytesToStringModifier(fields []string, modifier AddressModifier) Modifier { +func NewAddressBytesToStringModifier( + fields []string, + modifier AddressModifier, +) Modifier { + return NewPathTraverseAddressBytesToStringModifier(fields, modifier, false) +} + +func NewPathTraverseAddressBytesToStringModifier( + fields []string, + modifier AddressModifier, + enablePathTraverse bool, +) Modifier { // bool is a placeholder value fieldMap := map[string]bool{} for _, field := range fields { @@ -35,9 +46,10 @@ func NewAddressBytesToStringModifier(fields []string, modifier AddressModifier) m := &bytesToStringModifier{ modifier: modifier, modifierBase: modifierBase[bool]{ - fields: fieldMap, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: fieldMap, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, } @@ -60,7 +72,7 @@ type bytesToStringModifier struct { modifierBase[bool] } -func (t *bytesToStringModifier) RetypeToOffChain(onChainType reflect.Type, _ string) (tpe reflect.Type, err error) { +func (m *bytesToStringModifier) RetypeToOffChain(onChainType reflect.Type, _ string) (tpe reflect.Type, err error) { defer func() { // StructOf can panic if the fields are not valid if r := recover(); r != nil { @@ -70,11 +82,11 @@ func (t *bytesToStringModifier) RetypeToOffChain(onChainType reflect.Type, _ str }() // Attempt to retype using the shared functionality in modifierBase - offChainType, err := t.modifierBase.RetypeToOffChain(onChainType, "") + offChainType, err := m.modifierBase.RetypeToOffChain(onChainType, "") if err != nil { // Handle additional cases specific to bytesToStringModifier if onChainType.Kind() == reflect.Array { - addrType := reflect.ArrayOf(t.modifier.Length(), reflect.TypeOf(byte(0))) + addrType := reflect.ArrayOf(m.modifier.Length(), reflect.TypeOf(byte(0))) // Check for nested byte arrays (e.g., [n][20]byte) if onChainType.Elem() == addrType.Elem() { return reflect.ArrayOf(onChainType.Len(), reflect.TypeOf("")), nil @@ -86,16 +98,44 @@ func (t *bytesToStringModifier) RetypeToOffChain(onChainType reflect.Type, _ str } // TransformToOnChain uses the AddressModifier for string-to-address conversion. -func (t *bytesToStringModifier) TransformToOnChain(offChainValue any, _ string) (any, error) { - return transformWithMaps(offChainValue, t.offToOnChainType, t.fields, noop, stringToAddressHookForOnChain(t.modifier)) +func (m *bytesToStringModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { + offChainValue, itemType, err := m.modifierBase.selectType(offChainValue, m.offChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := transformWithMaps(offChainValue, m.offToOnChainType, m.fields, noop, stringToAddressHookForOnChain(m.modifier)) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } // TransformToOffChain uses the AddressModifier for address-to-string conversion. -func (t *bytesToStringModifier) TransformToOffChain(onChainValue any, _ string) (any, error) { - return transformWithMaps(onChainValue, t.onToOffChainType, t.fields, - addressTransformationAction(t.modifier.Length()), - addressToStringHookForOffChain(t.modifier), +func (m *bytesToStringModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { + onChainValue, itemType, err := m.modifierBase.selectType(onChainValue, m.onChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := transformWithMaps(onChainValue, m.onToOffChainType, m.fields, + addressTransformationAction(m.modifier.Length()), + addressToStringHookForOffChain(m.modifier), ) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } // addressTransformationAction performs conversions over the fields we want to modify. diff --git a/pkg/codec/config.go b/pkg/codec/config.go index 21b6cca04..21105970b 100644 --- a/pkg/codec/config.go +++ b/pkg/codec/config.go @@ -106,7 +106,8 @@ type ModifierConfig interface { // The casing of the first character is ignored to allow compatibility // of go convention for public fields and on-chain names. type RenameModifierConfig struct { - Fields map[string]string + Fields map[string]string + EnablePathTraverse bool } func (r *RenameModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { @@ -114,7 +115,8 @@ func (r *RenameModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Mod delete(r.Fields, k) r.Fields[upperFirstCharacter(k)] = upperFirstCharacter(v) } - return NewRenamer(r.Fields), nil + + return NewPathTraverseRenamer(r.Fields, r.EnablePathTraverse), nil } func (r *RenameModifierConfig) MarshalJSON() ([]byte, error) { @@ -130,7 +132,8 @@ func (r *RenameModifierConfig) MarshalJSON() ([]byte, error) { // For example, if a struct has fields A and B, and you want to rename A to B, // then you need to either also rename B or drop it. type DropModifierConfig struct { - Fields []string + Fields []string + EnablePathTraverse bool } func (d *DropModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { @@ -140,7 +143,7 @@ func (d *DropModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modif fields[upperFirstCharacter(f)] = fmt.Sprintf("dropFieldPrivateName%d", i) } - return NewRenamer(fields), nil + return NewPathTraverseRenamer(fields, d.EnablePathTraverse), nil } func (d *DropModifierConfig) MarshalJSON() ([]byte, error) { @@ -171,8 +174,9 @@ func (e *ElementExtractorModifierConfig) MarshalJSON() ([]byte, error) { // HardCodeModifierConfig is used to hard code values into the map. // Note that hard-coding values will override other values. type HardCodeModifierConfig struct { - OnChainValues map[string]any - OffChainValues map[string]any + OnChainValues map[string]any + OffChainValues map[string]any + EnablePathTraverse bool } func (h *HardCodeModifierConfig) ToModifier(onChainHooks ...mapstructure.DecodeHookFunc) (Modifier, error) { @@ -193,7 +197,7 @@ func (h *HardCodeModifierConfig) ToModifier(onChainHooks ...mapstructure.DecodeH mapKeyToUpperFirst(h.OnChainValues) mapKeyToUpperFirst(h.OffChainValues) - return NewHardCoder(h.OnChainValues, h.OffChainValues, onChainHooks...) + return NewPathTraverseHardCoder(h.OnChainValues, h.OffChainValues, h.EnablePathTraverse, onChainHooks...) } func (h *HardCodeModifierConfig) MarshalJSON() ([]byte, error) { @@ -246,7 +250,8 @@ type PreCodecModifierConfig struct { // If the path leads to an array, encoding will occur on every entry. // // Example: "a.b" -> "uint256 Value" - Fields map[string]string + Fields map[string]string + EnablePathTraverse bool // Codecs is skipped in JSON serialization, it will be injected later. // The map should be keyed using the value from "Fields" to a corresponding Codec that can encode/decode for it // This allows encoding and decoding implementations to be handled outside of the modifier. @@ -256,7 +261,7 @@ type PreCodecModifierConfig struct { } func (c *PreCodecModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { - return NewPreCodec(c.Fields, c.Codecs) + return NewPathTraversePreCodec(c.Fields, c.Codecs, c.EnablePathTraverse) } func (c *PreCodecModifierConfig) MarshalJSON() ([]byte, error) { @@ -268,14 +273,15 @@ func (c *PreCodecModifierConfig) MarshalJSON() ([]byte, error) { // EpochToTimeModifierConfig is used to convert epoch seconds as uint64 fields on-chain to time.Time type EpochToTimeModifierConfig struct { - Fields []string + Fields []string + EnablePathTraverse bool } func (e *EpochToTimeModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { for i, f := range e.Fields { e.Fields[i] = upperFirstCharacter(f) } - return NewEpochToTimeModifier(e.Fields), nil + return NewPathTraverseEpochToTimeModifier(e.Fields, e.EnablePathTraverse), nil } func (e *EpochToTimeModifierConfig) MarshalJSON() ([]byte, error) { @@ -303,13 +309,14 @@ func (c *PropertyExtractorConfig) MarshalJSON() ([]byte, error) { // AddressBytesToStringModifierConfig is used to transform address byte fields into string fields. // It holds the list of fields that should be modified and the chain-specific logic to do the modifications. type AddressBytesToStringModifierConfig struct { - Fields []string + Fields []string + EnablePathTraverse bool // Modifier is skipped in JSON serialization, will be injected later. Modifier AddressModifier `json:"-"` } func (c *AddressBytesToStringModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { - return NewAddressBytesToStringModifier(c.Fields, c.Modifier), nil + return NewPathTraverseAddressBytesToStringModifier(c.Fields, c.Modifier, c.EnablePathTraverse), nil } func (c *AddressBytesToStringModifierConfig) MarshalJSON() ([]byte, error) { @@ -374,7 +381,8 @@ func (c *AddressBytesToStringModifierConfig) MarshalJSON() ([]byte, error) { type WrapperModifierConfig struct { // Fields key defines the fields to be wrapped and the name of the wrapper struct. // The field becomes a subfield of the wrapper struct where the name of the subfield is map value. - Fields map[string]string + Fields map[string]string + EnablePathTraverse bool } func (r *WrapperModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Modifier, error) { @@ -383,7 +391,7 @@ func (r *WrapperModifierConfig) ToModifier(_ ...mapstructure.DecodeHookFunc) (Mo // using a private variable will make the field not serialize, essentially dropping the field fields[upperFirstCharacter(f)] = fmt.Sprintf("dropFieldPrivateName-%s", i) } - return NewWrapperModifier(r.Fields), nil + return NewPathTraverseWrapperModifier(r.Fields, r.EnablePathTraverse), nil } func (r *WrapperModifierConfig) MarshalJSON() ([]byte, error) { diff --git a/pkg/codec/encodings/struct.go b/pkg/codec/encodings/struct.go index 946936c45..84f359785 100644 --- a/pkg/codec/encodings/struct.go +++ b/pkg/codec/encodings/struct.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" + "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types" ) @@ -24,6 +25,8 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) { sfs := make([]reflect.StructField, len(fields)) codecFields := make([]TypeCodec, len(fields)) + lookup := make(map[string]int) + for i, field := range fields { ft := field.Codec.GetType() if ft.Kind() != reflect.Pointer { @@ -35,18 +38,22 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) { Name: field.Name, Type: ft, } + codecFields[i] = field.Codec + lookup[field.Name] = i } return &structCodec{ - fields: codecFields, - tpe: reflect.PointerTo(reflect.StructOf(sfs)), + fields: codecFields, + fieldLookup: lookup, + tpe: reflect.PointerTo(reflect.StructOf(sfs)), }, nil } type structCodec struct { - fields []TypeCodec - tpe reflect.Type + fields []TypeCodec + fieldLookup map[string]int + tpe reflect.Type } func (s *structCodec) Encode(value any, into []byte) ([]byte, error) { @@ -113,3 +120,34 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) { } return size, nil } + +func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) { + // itemType could recurse into nested structs + fieldName, tail := codec.ItemTyper(itemType).Next() + if fieldName == "" { + return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType) + } + + idx, ok := s.fieldLookup[fieldName] + if !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + if idx >= len(s.fields) { + return nil, fmt.Errorf("%w: field index out of range for type %s; cannot access field value", types.ErrInvalidType, itemType) + } + + fieldCodec := s.fields[idx] + + // if itemType wasn't referencing a nested field + if tail == "" { + return fieldCodec, nil + } + + structType, ok := fieldCodec.(StructTypeCodec) + if !ok { + return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType) + } + + return structType.FieldCodec(tail) +} diff --git a/pkg/codec/encodings/struct_test.go b/pkg/codec/encodings/struct_test.go index 0a9ace59c..f678d0aa8 100644 --- a/pkg/codec/encodings/struct_test.go +++ b/pkg/codec/encodings/struct_test.go @@ -14,6 +14,10 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" ) +type fieldCodec interface { + FieldCodec(string) (encodings.TypeCodec, error) +} + func TestStructCodec(t *testing.T) { t.Parallel() t.Run("NewStructCodec returns an error if names are repeated", func(t *testing.T) { @@ -176,6 +180,17 @@ func TestStructCodec(t *testing.T) { _, err := structCodecWithErr.SizeAtTopLevel(100) assert.Equal(t, errCodec.Err, err) }) + + t.Run("FieldCodec returns a nested field codec", func(t *testing.T) { + fc, ok := structCodec.(fieldCodec) + + require.True(t, ok) + + tc, err := fc.FieldCodec("Bar") + + require.NoError(t, err) + assert.Equal(t, reflect.PointerTo(reflect.TypeOf(uint64(0))), tc.GetType()) + }) } func toPointer[T any](t T) *T { diff --git a/pkg/codec/encodings/type_codec.go b/pkg/codec/encodings/type_codec.go index 1807df8c1..79d08db30 100644 --- a/pkg/codec/encodings/type_codec.go +++ b/pkg/codec/encodings/type_codec.go @@ -33,6 +33,11 @@ type TopLevelCodec interface { SizeAtTopLevel(numItems int) (int, error) } +type StructTypeCodec interface { + TypeCodec + FieldCodec(string) (TypeCodec, error) +} + // CodecFromTypeCodec maps TypeCodec to types.RemoteCodec, using the key as the itemType // If the TypeCodec is a TopLevelCodec, GetMaxEncodingSize and GetMaxDecodingSize will call SizeAtTopLevel instead of Size. type CodecFromTypeCodec map[string]TypeCodec @@ -45,9 +50,9 @@ type LenientCodecFromTypeCodec map[string]TypeCodec var _ types.RemoteCodec = &LenientCodecFromTypeCodec{} func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) { - ntcwt, ok := c[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return nil, err } tpe := ntcwt.GetType() @@ -59,9 +64,9 @@ func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) { } func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) ([]byte, error) { - ntcwt, ok := c[itemType] - if !ok { - return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return nil, err } if item != nil { @@ -86,14 +91,15 @@ func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) } func (c CodecFromTypeCodec) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) { - ntcwt, ok := c[itemType] - if !ok { - return 0, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return 0, err } if lp, ok := ntcwt.(TopLevelCodec); ok { return lp.SizeAtTopLevel(n) } + return ntcwt.Size(n) } @@ -121,11 +127,16 @@ func (c LenientCodecFromTypeCodec) Decode(ctx context.Context, raw []byte, into return decode(c, raw, into, itemType, false) } +func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) { + return c.GetMaxEncodingSize(ctx, n, itemType) +} + func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exactSize bool) error { - ntcwt, ok := c[itemType] - if !ok { - return fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + ntcwt, err := getCodec(c, itemType) + if err != nil { + return err } + val, remaining, err := ntcwt.Decode(raw) if err != nil { return err @@ -138,6 +149,31 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil) } -func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) { - return c.GetMaxEncodingSize(ctx, n, itemType) +func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) { + // itemType could recurse into nested structs + head, tail := codec.ItemTyper(itemType).Next() + if head == "" { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + ntcwt, ok := c[head] + if !ok { + if ntcwt, ok = c[itemType]; !ok { + return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType) + } + + // in this case, the codec is structured to not have nestable keys + return ntcwt, nil + } + + if tail == "" { + return ntcwt, nil + } + + structType, ok := ntcwt.(StructTypeCodec) + if !ok { + return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType) + } + + return structType.FieldCodec(tail) } diff --git a/pkg/codec/encodings/type_codec_test.go b/pkg/codec/encodings/type_codec_test.go index 874819ff2..ee9110153 100644 --- a/pkg/codec/encodings/type_codec_test.go +++ b/pkg/codec/encodings/type_codec_test.go @@ -4,6 +4,7 @@ import ( rawbin "encoding/binary" "math" "reflect" + "strings" "testing" "github.com/smartcontractkit/libocr/bigbigendian" @@ -122,6 +123,34 @@ func TestCodecFromTypeCodecs(t *testing.T) { assert.Equal(t, singleItemSize*2, actual) }) + + t.Run("CreateType works for nested struct values and modifiers", func(t *testing.T) { + itemType := strings.Join([]string{TestItemWithConfigExtra, "AccountStruct", "Account"}, ".") + ts := CreateTestStruct(0, biit) + c := biit.GetNestableCodec(t) + + encoded, err := c.Encode(tests.Context(t), ts.AccountStruct.Account, itemType) + require.NoError(t, err) + + var actual []byte + require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType)) + + assert.Equal(t, ts.AccountStruct.Account, actual) + }) + + t.Run("CreateType works for nested struct values", func(t *testing.T) { + itemType := strings.Join([]string{TestItemType, "NestedDynamicStruct", "Inner", "S"}, ".") + ts := CreateTestStruct(0, biit) + c := biit.GetNestableCodec(t) + + encoded, err := c.Encode(tests.Context(t), ts.NestedDynamicStruct.Inner.S, itemType) + require.NoError(t, err) + + var actual string + require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType)) + + assert.Equal(t, ts.NestedDynamicStruct.Inner.S, actual) + }) } type interfaceTesterBase struct{} @@ -319,7 +348,61 @@ func (b *bigEndianInterfaceTester) GetCodec(t *testing.T) types.Codec { modCodec, err := codec.NewModifierCodec(c, byTypeMod, codec.BigIntHook) require.NoError(t, err) - _, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), TestItemWithConfigExtra) + _, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), "") + require.NoError(t, err) + + return modCodec +} + +func (b *bigEndianInterfaceTester) GetNestableCodec(t *testing.T) types.Codec { + testStruct := newTestStructCodec(t, binary.BigEndian()) + size, err := binary.BigEndian().Int(1) + require.NoError(t, err) + slice, err := encodings.NewSlice(testStruct, size) + require.NoError(t, err) + arr1, err := encodings.NewArray(1, testStruct) + require.NoError(t, err) + arr2, err := encodings.NewArray(2, testStruct) + require.NoError(t, err) + + ts := CreateTestStruct(0, b) + + tc := &encodings.CodecFromTypeCodec{ + TestItemType: testStruct, + TestItemSliceType: slice, + TestItemArray1Type: arr1, + TestItemArray2Type: arr2, + TestItemWithConfigExtra: testStruct, + NilType: encodings.Empty{}, + } + + require.NoError(t, err) + + var c types.RemoteCodec = tc + if b.lenient { + c = (*encodings.LenientCodecFromTypeCodec)(tc) + } + + mod, err := codec.NewPathTraverseHardCoder(map[string]any{ + "BigField": ts.BigField.String(), + "AccountStruct.Account": ts.AccountStruct.Account, + }, map[string]any{"ExtraField": AnyExtraValue}, true, codec.BigIntHook) + require.NoError(t, err) + + byTypeMod, err := codec.NewNestableByItemTypeModifier(map[string]codec.Modifier{ + TestItemType: codec.MultiModifier{}, + TestItemSliceType: codec.MultiModifier{}, + TestItemArray1Type: codec.MultiModifier{}, + TestItemArray2Type: codec.MultiModifier{}, + TestItemWithConfigExtra: mod, + NilType: codec.MultiModifier{}, + }) + require.NoError(t, err) + + modCodec, err := codec.NewModifierCodec(c, byTypeMod, codec.BigIntHook) + require.NoError(t, err) + + _, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), "") require.NoError(t, err) return modCodec diff --git a/pkg/codec/epoch_to_time.go b/pkg/codec/epoch_to_time.go index 287de807d..15ea07403 100644 --- a/pkg/codec/epoch_to_time.go +++ b/pkg/codec/epoch_to_time.go @@ -11,6 +11,10 @@ import ( // NewEpochToTimeModifier converts all fields from time.Time off-chain to int64. func NewEpochToTimeModifier(fields []string) Modifier { + return NewPathTraverseEpochToTimeModifier(fields, false) +} + +func NewPathTraverseEpochToTimeModifier(fields []string, enablePathTraverse bool) Modifier { fieldMap := map[string]bool{} for _, field := range fields { fieldMap[field] = true @@ -18,9 +22,10 @@ func NewEpochToTimeModifier(fields []string) Modifier { m := &timeToUnixModifier{ modifierBase: modifierBase[bool]{ - fields: fieldMap, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: fieldMap, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, } @@ -40,14 +45,42 @@ type timeToUnixModifier struct { modifierBase[bool] } -func (t *timeToUnixModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { +func (m *timeToUnixModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { + offChainValue, itemType, err := m.modifierBase.selectType(offChainValue, m.offChainStructType, itemType) + if err != nil { + return nil, err + } + // since the hook will convert time.Time to epoch, we don't need to worry about converting them in the maps - return transformWithMaps(offChainValue, t.offToOnChainType, t.fields, noop, EpochToTimeHook, BigIntHook) + modified, err := transformWithMaps(offChainValue, m.offToOnChainType, m.fields, noop, EpochToTimeHook, BigIntHook) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } -func (t *timeToUnixModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { +func (m *timeToUnixModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { + onChainValue, itemType, err := m.modifierBase.selectType(onChainValue, m.onChainStructType, itemType) + if err != nil { + return nil, err + } + // since the hook will convert epoch to time.Time, we don't need to worry about converting them in the maps - return transformWithMaps(onChainValue, t.onToOffChainType, t.fields, noop, EpochToTimeHook, BigIntHook) + modified, err := transformWithMaps(onChainValue, m.onToOffChainType, m.fields, noop, EpochToTimeHook, BigIntHook) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } func noop(_ map[string]any, _ string, _ bool) error { diff --git a/pkg/codec/hard_coder.go b/pkg/codec/hard_coder.go index 9f946fa04..a224568af 100644 --- a/pkg/codec/hard_coder.go +++ b/pkg/codec/hard_coder.go @@ -10,10 +10,23 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" ) -// NewHardCoder creates a modifier that will hard-code values for on-chain and off-chain types -// The modifier will override any values of the same name, if you need an overwritten value to be used in a different field, -// NewRenamer must be used before NewHardCoder. -func NewHardCoder(onChain map[string]any, offChain map[string]any, hooks ...mapstructure.DecodeHookFunc) (Modifier, error) { +// NewHardCoder creates a modifier that will hard-code values for on-chain and off-chain types. The modifier will +// override any values of the same name, if you need an overwritten value to be used in a different field. NewRenamer +// must be used before NewHardCoder. +func NewHardCoder( + onChain map[string]any, + offChain map[string]any, + hooks ...mapstructure.DecodeHookFunc, +) (Modifier, error) { + return NewPathTraverseHardCoder(onChain, offChain, false, hooks...) +} + +func NewPathTraverseHardCoder( + onChain map[string]any, + offChain map[string]any, + enablePathTraverse bool, + hooks ...mapstructure.DecodeHookFunc, +) (Modifier, error) { if err := verifyHardCodeKeys(onChain); err != nil { return nil, err } else if err = verifyHardCodeKeys(offChain); err != nil { @@ -26,9 +39,10 @@ func NewHardCoder(onChain map[string]any, offChain map[string]any, hooks ...maps m := &onChainHardCoder{ modifierBase: modifierBase[any]{ - fields: offChain, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: offChain, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, onChain: onChain, hooks: myHooks, @@ -81,15 +95,58 @@ func verifyHardCodeKeys(values map[string]any) error { return nil } -func (o *onChainHardCoder) TransformToOnChain(offChainValue any, _ string) (any, error) { - return transformWithMaps(offChainValue, o.offToOnChainType, o.onChain, hardCode, o.hooks...) +// TransformToOnChain will apply the hard-code modifier and hooks on the value identified by itemType. If path traverse +// is not enabled, itemType is ignored. +// +// For path-traversal, the itemType may reference a field that does not exist in the off-chain type, but is being added +// by the hard-code modifier. Ex. offChain A.B (does not have 'C') and onChain A.B.C ('C' gets hard-coded); if 'C' is +// intended to be the result of the transformation, itemType must be A.B.C even though the off-chain type does not have +// field 'C'. +func (m *onChainHardCoder) TransformToOnChain(offChainValue any, itemType string) (any, error) { + offChainValue, itemType, err := m.modifierBase.selectType(offChainValue, m.offChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := transformWithMaps(offChainValue, m.offToOnChainType, m.onChain, hardCode, m.hooks...) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } -func (o *onChainHardCoder) TransformToOffChain(onChainValue any, _ string) (any, error) { - allHooks := make([]mapstructure.DecodeHookFunc, len(o.hooks)+1) - copy(allHooks, o.hooks) - allHooks[len(o.hooks)] = hardCodeManyHook - return transformWithMaps(onChainValue, o.onToOffChainType, o.fields, hardCode, allHooks...) +// TransformToOffChain will apply the hard-code modifier and hooks on the value identified by itemType. If path traverse +// is not enabled, itemType is ignored. +// +// For path-traversal, the itemType may reference a field that does not exist in the on-chain type, but is being added +// by the hard-code modifier. Ex. on-chain A.B (does not have 'C') and off-chain A.B.C ('C' gets hard-coded); if 'C' is +// intended to be the result of the transformation, itemType must be A.B.C even though the on-chain type does not have +// field 'C'. +func (m *onChainHardCoder) TransformToOffChain(onChainValue any, itemType string) (any, error) { + onChainValue, itemType, err := m.modifierBase.selectType(onChainValue, m.onChainStructType, itemType) + if err != nil { + return nil, err + } + + allHooks := make([]mapstructure.DecodeHookFunc, len(m.hooks)+1) + copy(allHooks, m.hooks) + allHooks[len(m.hooks)] = hardCodeManyHook + + modified, err := transformWithMaps(onChainValue, m.onToOffChainType, m.fields, hardCode, allHooks...) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } func hardCode(extractMap map[string]any, key string, item any) error { diff --git a/pkg/codec/hard_coder_test.go b/pkg/codec/hard_coder_test.go index 6dc1ba0ab..13b413c95 100644 --- a/pkg/codec/hard_coder_test.go +++ b/pkg/codec/hard_coder_test.go @@ -469,6 +469,52 @@ func TestHardCoder(t *testing.T) { require.NoError(t, err) assert.Equal(t, int32(123), reflect.ValueOf(offChain).FieldByName("B").Interface()) }) + + t.Run("TransformToOnChain and TransformToOffChain works for itemType path", func(t *testing.T) { + nestedHardCoder, err := codec.NewPathTraverseHardCoder(map[string]any{ + "A": "Top", + "B.A": "Foo", + "B.C": []int32{2, 3}, + "C.A": "Foo", + "C.C": []int32{2, 3}, + }, map[string]any{ + "B.Z": "Bar", + "B.Q": []struct { + A int + B string + }{{1, "a"}, {2, "b"}}, + "C.Z": "Bar", + "C.Q": []struct { + A int + B string + }{{1, "a"}, {2, "b"}}, + }, true) + require.NoError(t, err) + + offChainType, err := nestedHardCoder.RetypeToOffChain(reflect.TypeOf(nestedTestStruct{}), "") + require.NoError(t, err) + + _, err = nestedHardCoder.RetypeToOffChain(reflect.TypeOf(""), "B.A") + require.NoError(t, err) + + iInput := reflect.Indirect(reflect.New(offChainType)) + iB := iInput.FieldByName("B") + iB.FieldByName("B").SetInt(1) + iC := iInput.FieldByName("C") + iC.Set(reflect.MakeSlice(iC.Type(), 2, 2)) + iC.Index(0).FieldByName("B").SetInt(2) + iC.Index(1).FieldByName("B").SetInt(3) + iInput.FieldByName("D").SetInt(1) + + actual, err := nestedHardCoder.TransformToOnChain(iInput.FieldByName("B").FieldByName("A").Interface(), "B.A") + require.NoError(t, err) + + expected := "Foo" + assert.Equal(t, expected, actual) + + _, err = nestedHardCoder.TransformToOffChain(expected, "B.A") + require.NoError(t, err) + }) } // Since we're using the on-chain values that have their hard-coded values set to diff --git a/pkg/codec/modifier.go b/pkg/codec/modifier.go index a25a44599..da3e7eda3 100644 --- a/pkg/codec/modifier.go +++ b/pkg/codec/modifier.go @@ -7,13 +7,25 @@ import ( // Modifier allows you to modify the off-chain type to be used on-chain, and vice-versa. // A modifier is set up by retyping the on-chain type to a type used off-chain. type Modifier interface { + // RetypeToOffChain will retype the onChainType to its correlated offChainType. The itemType should be empty for an + // expected whole struct. A dot-separated string can be provided when path traversal is supported on the modifier + // to retype a nested field. + // + // For most modifiers, RetypeToOffChain must be called first with the entire struct to be retyped/modified before + // any other transformations or path traversal can function. RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) // TransformToOnChain transforms a type returned from AdjustForInput into the outputType. // You may also pass a pointer to the type returned by AdjustForInput to get a pointer to outputType. + // + // Modifiers should also optionally provide support for path traversal using itemType. In the case of using path + // traversal, the offChainValue should be the field value being modified as identified by itemType. TransformToOnChain(offChainValue any, itemType string) (any, error) // TransformToOffChain is the reverse of TransformForOnChain input. // It is used to send back the object after it has been decoded + // + // Modifiers should also optionally provide support for path traversal using itemType. In the case of using path + // traversal, the onChainValue should be the field value being modified as identified by itemType. TransformToOffChain(onChainValue any, itemType string) (any, error) } diff --git a/pkg/codec/modifier_base.go b/pkg/codec/modifier_base.go index 8a092fe9b..d1d939f03 100644 --- a/pkg/codec/modifier_base.go +++ b/pkg/codec/modifier_base.go @@ -12,14 +12,22 @@ import ( ) type modifierBase[T any] struct { + enablePathTraverse bool fields map[string]T onToOffChainType map[reflect.Type]reflect.Type offToOnChainType map[reflect.Type]reflect.Type modifyFieldForInput func(pkgPath string, outputField *reflect.StructField, fullPath string, change T) error addFieldForInput func(pkgPath, name string, change T) reflect.StructField + onChainStructType reflect.Type + offChainStructType reflect.Type } +// RetypeToOffChain sets the on-chain and off-chain types for modifications. If itemType is empty, the type returned +// will be the full off-chain type and all type mappings will be reset. If itemType is not empty, retyping assumes a +// sub-field is expected and the off-chain type of the sub-field is returned with no modifications to internal type +// mappings. func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType string) (tpe reflect.Type, err error) { + // onChainType could be the entire struct or a sub-field type defer func() { // StructOf can panic if the fields are not valid if r := recover(); r != nil { @@ -27,48 +35,85 @@ func (m *modifierBase[T]) RetypeToOffChain(onChainType reflect.Type, itemType st err = fmt.Errorf("%w: %v", types.ErrInvalidType, r) } }() + + // path traverse allows an item type of Struct.FieldA.NestedField to isolate modifiers + // associated with the nested field `NestedField`. + if !m.enablePathTraverse { + itemType = "" + } + + // if itemType is empty, store the type mappings + // if itemType is not empty, assume a sub-field property is expected to be extracted + onChainStructType := onChainType + if itemType != "" { + onChainStructType = m.onChainStructType + } + + // this will only work for the full on-chain struct type unless we cache the individual + // field types too. + if cached, ok := m.onToOffChainType[onChainStructType]; ok { + return typeForPath(cached, itemType) + } + if len(m.fields) == 0 { m.offToOnChainType[onChainType] = onChainType m.onToOffChainType[onChainType] = onChainType - return onChainType, nil - } + m.onChainStructType = onChainType + m.offChainStructType = onChainType - if cached, ok := m.onToOffChainType[onChainType]; ok { - return cached, nil + return typeForPath(onChainType, itemType) } var offChainType reflect.Type - switch onChainType.Kind() { + + // the onChainStructType here should always reference the full on-chain struct type + switch onChainStructType.Kind() { case reflect.Pointer: - elm, err := m.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = m.RetypeToOffChain(onChainStructType.Elem(), itemType); err != nil { return nil, err } offChainType = reflect.PointerTo(elm) case reflect.Slice: - elm, err := m.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = m.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } offChainType = reflect.SliceOf(elm) case reflect.Array: - elm, err := m.RetypeToOffChain(onChainType.Elem(), "") - if err != nil { + var elm reflect.Type + + if elm, err = m.RetypeToOffChain(onChainStructType.Elem(), ""); err != nil { return nil, err } - offChainType = reflect.ArrayOf(onChainType.Len(), elm) + offChainType = reflect.ArrayOf(onChainStructType.Len(), elm) case reflect.Struct: - return m.getStructType(onChainType) + if offChainType, err = m.getStructType(onChainStructType); err != nil { + return nil, err + } default: - return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainType.Kind()) + // if the types don't match, it means we are attempting to traverse the main struct + if onChainType != m.onChainStructType { + return onChainType, nil + } + + return nil, fmt.Errorf("%w: cannot retype the kind %v", types.ErrInvalidType, onChainStructType.Kind()) + } + + m.onToOffChainType[onChainStructType] = offChainType + m.offToOnChainType[offChainType] = onChainStructType + + if m.onChainStructType == nil { + m.onChainStructType = onChainType + m.offChainStructType = offChainType } - m.onToOffChainType[onChainType] = offChainType - m.offToOnChainType[offChainType] = onChainType - return offChainType, nil + return typeForPath(offChainType, itemType) } func (m *modifierBase[T]) getStructType(outputType reflect.Type) (reflect.Type, error) { @@ -78,10 +123,11 @@ func (m *modifierBase[T]) getStructType(outputType reflect.Type) (reflect.Type, } for _, key := range m.subkeysFirst() { + curLocations := filedLocations parts := strings.Split(key, ".") fieldName := parts[len(parts)-1] + parts = parts[:len(parts)-1] - curLocations := filedLocations for _, part := range parts { if curLocations, err = curLocations.populateSubFields(part); err != nil { return nil, err @@ -102,10 +148,7 @@ func (m *modifierBase[T]) getStructType(outputType reflect.Type) (reflect.Type, } } - newStruct := filedLocations.makeNewType() - m.onToOffChainType[outputType] = newStruct - m.offToOnChainType[newStruct] = outputType - return newStruct, nil + return filedLocations.makeNewType(), nil } // subkeysFirst returns a list of keys that will always have a sub-key before the key if both are present @@ -122,6 +165,59 @@ func (m *modifierBase[T]) subkeysFirst() []string { return orderedKeys } +func (m *modifierBase[T]) onToOffChainTyper(onChainType reflect.Type, itemType string) (reflect.Type, error) { + onChainRefType := onChainType + if itemType != "" { + onChainRefType = m.onChainStructType + } + + offChainType, ok := m.onToOffChainType[onChainRefType] + if !ok { + return nil, fmt.Errorf("%w: cannot rename unknown type %v", types.ErrInvalidType, onChainType) + } + + return typeForPath(offChainType, itemType) +} + +func (m *modifierBase[T]) offToOnChainTyper(offChainType reflect.Type, itemType string) (reflect.Type, error) { + offChainRefType := offChainType + if itemType != "" { + offChainRefType = m.offChainStructType + } + + onChainType, ok := m.offToOnChainType[offChainRefType] + if !ok { + return nil, fmt.Errorf("%w: cannot rename unknown type %v", types.ErrInvalidType, offChainType) + } + + return typeForPath(onChainType, itemType) +} + +func (m *modifierBase[T]) selectType(inputValue any, savedType reflect.Type, itemType string) (any, string, error) { + // set itemType to an ignore value if path traversal is not enabled + if !m.enablePathTraverse { + return inputValue, "", nil + } + + // the offChainValue might be a subfield value; get the true offChainStruct type already stored and set the value + baseStructValue := inputValue + + // path traversal is expected, but offChainValue is the value of a field, not the actual struct + // create a new struct from the stored offChainStruct with the provided value applied and all other fields set to + // their zero value. + if itemType != "" { + into := reflect.New(savedType) + + if err := applyValueForPath(into, reflect.ValueOf(inputValue), itemType); err != nil { + return nil, itemType, err + } + + baseStructValue = reflect.Indirect(into).Interface() + } + + return baseStructValue, itemType, nil +} + // subkeysLast returns a list of keys that will always have a sub-key after the key if both are present func subkeysLast[T any](fields map[string]T) []string { orderedKeys := make([]string, 0, len(fields)) @@ -130,6 +226,7 @@ func subkeysLast[T any](fields map[string]T) []string { } sort.Strings(orderedKeys) + return orderedKeys } @@ -264,6 +361,117 @@ func doForMapElements[T any](valueMapping map[string]any, fields map[string]T, f return nil } +func typeForPath(from reflect.Type, itemType string) (reflect.Type, error) { + if itemType == "" { + return from, nil + } + + switch from.Kind() { + case reflect.Pointer: + elem, err := typeForPath(from.Elem(), itemType) + if err != nil { + return nil, err + } + + return elem, nil + case reflect.Array, reflect.Slice: + return nil, fmt.Errorf("%w: cannot extract a field from an array or slice", types.ErrInvalidType) + case reflect.Struct: + head, tail := ItemTyper(itemType).Next() + + field, ok := from.FieldByName(head) + if !ok { + return nil, fmt.Errorf("%w: field not found for path %s and itemType %s", types.ErrInvalidType, from, itemType) + } + + if tail == "" { + return field.Type, nil + } + + return typeForPath(field.Type, tail) + default: + return nil, fmt.Errorf("%w: cannot extract a field from kind %s", types.ErrInvalidType, from.Kind()) + } +} + +func valueForPath(from reflect.Value, itemType string) (any, error) { + if itemType == "" { + return from.Interface(), nil + } + + switch from.Kind() { + case reflect.Pointer: + elem, err := valueForPath(from.Elem(), itemType) + if err != nil { + return nil, err + } + + return elem, nil + case reflect.Array, reflect.Slice: + return nil, fmt.Errorf("%w: cannot extract a field from an array or slice", types.ErrInvalidType) + case reflect.Struct: + head, tail := ItemTyper(itemType).Next() + + field := from.FieldByName(head) + if !field.IsValid() { + return nil, fmt.Errorf("%w: field not found for path %s and itemType %s", types.ErrInvalidType, from, itemType) + } + + if tail == "" { + return field.Interface(), nil + } + + return valueForPath(field, tail) + default: + return nil, fmt.Errorf("%w: cannot extract a field from kind %s", types.ErrInvalidType, from.Kind()) + } +} + +func applyValueForPath(vInto, vField reflect.Value, itemType string) error { + switch vInto.Kind() { + case reflect.Pointer: + if !vInto.Elem().IsValid() { + into := reflect.New(vInto.Type().Elem()) + + vInto.Set(into) + } + + err := applyValueForPath(vInto.Elem(), vField, itemType) + if err != nil { + return err + } + + return nil + case reflect.Array, reflect.Slice: + return fmt.Errorf("%w: cannot set a field from an array or slice", types.ErrInvalidType) + case reflect.Struct: + head, tail := ItemTyper(itemType).Next() + + field := vInto.FieldByName(head) + if !field.IsValid() { + return fmt.Errorf("%w: invalid field for type %s and name %s", types.ErrInvalidType, vInto, head) + } + + if tail == "" { + if field.Type() != vField.Type() { + return fmt.Errorf("%w: value type mismatch for field %s", types.ErrInvalidType, head) + } + + if !field.CanSet() { + return fmt.Errorf("%w: cannot set field %s", types.ErrInvalidType, head) + } + + field.Set(vField) + + return nil + } + + return applyValueForPath(field, vField, tail) + default: + return fmt.Errorf("%w: cannot set a field from kind %s", types.ErrInvalidType, vInto.Kind()) + } +} + type PathMappingError struct { Err error Path string @@ -276,3 +484,18 @@ func (e PathMappingError) Error() string { func (e PathMappingError) Cause() error { return e.Err } + +type ItemTyper string + +func (t ItemTyper) Next() (string, string) { + if string(t) == "" { + return "", "" + } + + path := strings.Split(string(t), ".") + if len(path) == 1 { + return path[0], "" + } + + return path[0], strings.Join(path[1:], ".") +} diff --git a/pkg/codec/precodec.go b/pkg/codec/precodec.go index de5dec055..cbbd4e7c2 100644 --- a/pkg/codec/precodec.go +++ b/pkg/codec/precodec.go @@ -9,14 +9,28 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" ) -// PreCodec creates a modifier that will run a preliminary encoding/decoding step. -// This is useful when wanting to move nested data as generic bytes. -func NewPreCodec(fields map[string]string, codecs map[string]types.RemoteCodec) (Modifier, error) { +// NewPreCodec creates a modifier that will run a preliminary encoding/decoding step. This is useful when wanting to +// move nested data as generic bytes. +func NewPreCodec( + fields map[string]string, + codecs map[string]types.RemoteCodec, +) (Modifier, error) { + return NewPathTraversePreCodec(fields, codecs, false) +} + +// NewPathTraversePreCodec creates a PreCodec modifier with itemType path traversal enabled or disabled. The standard +// constructor. NewPreCodec has path traversal off by default. +func NewPathTraversePreCodec( + fields map[string]string, + codecs map[string]types.RemoteCodec, + enablePathTraverse bool, +) (Modifier, error) { m := &preCodec{ modifierBase: modifierBase[string]{ - fields: fields, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: fields, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, codecs: codecs, } @@ -56,11 +70,41 @@ type preCodec struct { codecs map[string]types.RemoteCodec } -func (pc *preCodec) TransformToOffChain(onChainValue any, _ string) (any, error) { +func (pc *preCodec) TransformToOffChain(onChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !pc.modifierBase.enablePathTraverse { + itemType = "" + } + allHooks := make([]mapstructure.DecodeHookFunc, 1) allHooks[0] = hardCodeManyHook - return transformWithMaps(onChainValue, pc.onToOffChainType, pc.fields, pc.decodeFieldMapAction, allHooks...) + // the offChainValue might be a subfield value; get the true offChainStruct type already stored and set the value + onChainStructValue := onChainValue + + // path traversal is expected, but offChainValue is the value of a field, not the actual struct + // create a new struct from the stored offChainStruct with the provided value applied and all other fields set to + // their zero value. + if itemType != "" { + into := reflect.New(pc.onChainStructType) + + if err := applyValueForPath(into, reflect.ValueOf(onChainValue), itemType); err != nil { + return nil, err + } + + onChainStructValue = reflect.Indirect(into).Interface() + } + + modified, err := transformWithMaps(onChainStructValue, pc.onToOffChainType, pc.fields, pc.decodeFieldMapAction, allHooks...) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } func (pc *preCodec) decodeFieldMapAction(extractMap map[string]any, key string, typeDef string) error { @@ -86,11 +130,41 @@ func (pc *preCodec) decodeFieldMapAction(extractMap map[string]any, key string, return nil } -func (pc *preCodec) TransformToOnChain(offChainValue any, _ string) (any, error) { +func (pc *preCodec) TransformToOnChain(offChainValue any, itemType string) (any, error) { allHooks := make([]mapstructure.DecodeHookFunc, 1) allHooks[0] = hardCodeManyHook - return transformWithMaps(offChainValue, pc.offToOnChainType, pc.fields, pc.encodeFieldMapAction, allHooks...) + // set itemType to an ignore value if path traversal is not enabled + if !pc.modifierBase.enablePathTraverse { + itemType = "" + } + + // the offChainValue might be a subfield value; get the true offChainStruct type already stored and set the value + offChainStructValue := offChainValue + + // path traversal is expected, but offChainValue is the value of a field, not the actual struct + // create a new struct from the stored offChainStruct with the provided value applied and all other fields set to + // their zero value. + if itemType != "" { + into := reflect.New(pc.offChainStructType) + + if err := applyValueForPath(into, reflect.ValueOf(offChainValue), itemType); err != nil { + return nil, err + } + + offChainStructValue = reflect.Indirect(into).Interface() + } + + modified, err := transformWithMaps(offChainStructValue, pc.offToOnChainType, pc.fields, pc.encodeFieldMapAction, allHooks...) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } func (pc *preCodec) encodeFieldMapAction(extractMap map[string]any, key string, typeDef string) error { diff --git a/pkg/codec/renamer.go b/pkg/codec/renamer.go index b2414964a..fed6b3f53 100644 --- a/pkg/codec/renamer.go +++ b/pkg/codec/renamer.go @@ -3,17 +3,23 @@ package codec import ( "fmt" "reflect" + "strings" "unicode" "github.com/smartcontractkit/chainlink-common/pkg/types" ) func NewRenamer(fields map[string]string) Modifier { + return NewPathTraverseRenamer(fields, false) +} + +func NewPathTraverseRenamer(fields map[string]string, enablePathTraverse bool) Modifier { m := &renamer{ modifierBase: modifierBase[string]{ - fields: fields, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: fields, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, } m.modifyFieldForInput = func(pkgPath string, field *reflect.StructField, _, newName string) error { @@ -30,26 +36,80 @@ type renamer struct { modifierBase[string] } -func (r *renamer) TransformToOffChain(onChainValue any, _ string) (any, error) { - rOutput, err := renameTransform(r.onToOffChainType, reflect.ValueOf(onChainValue)) +func (r *renamer) TransformToOffChain(onChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !r.modifierBase.enablePathTraverse { + itemType = "" + } + + // itemType references the on-chain type + // rename field/subfield path in itemType to match the modifier renaming + if itemType != "" { + var ref string + + parts := strings.Split(itemType, ".") + if len(parts) > 0 { + ref = parts[len(parts)-1] + } + + for on, off := range r.fields { + if ref == on { + // B.A -> C == B.C + parts[len(parts)-1] = off + itemType = strings.Join(parts, ".") + + break + } + } + } + + rOutput, err := renameTransform(r.onToOffChainTyper, reflect.ValueOf(onChainValue), itemType) if err != nil { return nil, err } + return rOutput.Interface(), nil } -func (r *renamer) TransformToOnChain(offChainValue any, _ string) (any, error) { - rOutput, err := renameTransform(r.offToOnChainType, reflect.ValueOf(offChainValue)) +func (r *renamer) TransformToOnChain(offChainValue any, itemType string) (any, error) { + // set itemType to an ignore value if path traversal is not enabled + if !r.modifierBase.enablePathTraverse { + itemType = "" + } + + if itemType != "" { + var ref string + + parts := strings.Split(itemType, ".") + if len(parts) > 0 { + ref = parts[len(parts)-1] + } + + for on, off := range r.fields { + if ref == off { + itemType = on + + break + } + } + } + + rOutput, err := renameTransform(r.offToOnChainTyper, reflect.ValueOf(offChainValue), itemType) if err != nil { return nil, err } + return rOutput.Interface(), nil } -func renameTransform(typeMap map[reflect.Type]reflect.Type, rInput reflect.Value) (reflect.Value, error) { - toType, ok := typeMap[rInput.Type()] - if !ok { - return reflect.Value{}, fmt.Errorf("%w: cannot rename unknown type %v", types.ErrInvalidType, toType) +func renameTransform( + typeFunc func(reflect.Type, string) (reflect.Type, error), + rInput reflect.Value, + itemType string, +) (reflect.Value, error) { + toType, err := typeFunc(rInput.Type(), itemType) + if err != nil { + return reflect.Value{}, err } if toType == rInput.Type() { @@ -70,6 +130,10 @@ func transformNonPointer(toType reflect.Type, rInput reflect.Value) (reflect.Val // make sure the input is addressable ptr := reflect.New(rInput.Type()) reflect.Indirect(ptr).Set(rInput) + + // UnsafePointer is a bit of a Go hack but works because the data types/structure and data for the two types + // are the same. The only change is the names of the fields. changed := reflect.NewAt(toType, ptr.UnsafePointer()).Elem() + return changed, nil } diff --git a/pkg/codec/renamer_test.go b/pkg/codec/renamer_test.go index 55453ff16..9fbc170c8 100644 --- a/pkg/codec/renamer_test.go +++ b/pkg/codec/renamer_test.go @@ -28,9 +28,9 @@ func TestRenamer(t *testing.T) { D string } - renamer := codec.NewRenamer(map[string]string{"A": "X", "C": "Z"}) - invalidRenamer := codec.NewRenamer(map[string]string{"W": "X", "C": "Z"}) - nestedRenamer := codec.NewRenamer(map[string]string{"A": "X", "B.A": "X", "B.C": "Z", "C.A": "X", "C.C": "Z", "B": "Y"}) + renamer := codec.NewPathTraverseRenamer(map[string]string{"A": "X", "C": "Z"}, true) + invalidRenamer := codec.NewPathTraverseRenamer(map[string]string{"W": "X", "C": "Z"}, true) + nestedRenamer := codec.NewPathTraverseRenamer(map[string]string{"A": "X", "B.A": "X", "B.C": "Z", "C.A": "X", "C.C": "Z", "B": "Y"}, true) t.Run("RetypeToOffChain renames fields keeping structure", func(t *testing.T) { offChainType, err := renamer.RetypeToOffChain(reflect.TypeOf(testStruct{}), "") require.NoError(t, err) @@ -385,6 +385,45 @@ func TestRenamer(t *testing.T) { require.NoError(t, err) assert.Equal(t, iOffchain.Interface(), newInput) }) + + t.Run("TransformToOnChain and TransformToOffChain works on nested fields even if the field itself is renamed for path", func(t *testing.T) { + offChainType, err := nestedRenamer.RetypeToOffChain(reflect.TypeOf(nestedTestStruct{}), "") + require.NoError(t, err) + iOffchain := reflect.Indirect(reflect.New(offChainType)) + + iOffchain.FieldByName("X").SetString("foo") + rY := iOffchain.FieldByName("Y") + rY.FieldByName("X").SetString("foo") + rY.FieldByName("B").SetInt(10) + rY.FieldByName("Z").SetInt(20) + + rC := iOffchain.FieldByName("C") + rC.Set(reflect.MakeSlice(rC.Type(), 2, 2)) + iElm := rC.Index(0) + iElm.FieldByName("X").SetString("foo") + iElm.FieldByName("B").SetInt(10) + iElm.FieldByName("Z").SetInt(20) + iElm = rC.Index(1) + iElm.FieldByName("X").SetString("baz") + iElm.FieldByName("B").SetInt(15) + iElm.FieldByName("Z").SetInt(25) + + iOffchain.FieldByName("D").SetString("bar") + + output, err := nestedRenamer.TransformToOnChain(iOffchain.FieldByName("Y").Interface(), "Y") + + require.NoError(t, err) + + expected := testStruct{ + A: "foo", + B: 10, + C: 20, + } + assert.Equal(t, expected, output) + newInput, err := nestedRenamer.TransformToOffChain(expected, "B") + require.NoError(t, err) + assert.Equal(t, iOffchain.FieldByName("Y").Interface(), newInput) + }) } func assertBasicRenameTransform(t *testing.T, offChainType reflect.Type) { diff --git a/pkg/codec/wrapper.go b/pkg/codec/wrapper.go index dd1061244..fb6ad0a3c 100644 --- a/pkg/codec/wrapper.go +++ b/pkg/codec/wrapper.go @@ -5,12 +5,18 @@ import ( "reflect" ) +// NewWrapperModifier creates a modifier that will wrap specified on-chain fields in a struct. func NewWrapperModifier(fields map[string]string) Modifier { + return NewPathTraverseWrapperModifier(fields, false) +} + +func NewPathTraverseWrapperModifier(fields map[string]string, enablePathTraverse bool) Modifier { m := &wrapperModifier{ modifierBase: modifierBase[string]{ - fields: fields, - onToOffChainType: map[reflect.Type]reflect.Type{}, - offToOnChainType: map[reflect.Type]reflect.Type{}, + enablePathTraverse: enablePathTraverse, + fields: fields, + onToOffChainType: map[reflect.Type]reflect.Type{}, + offToOnChainType: map[reflect.Type]reflect.Type{}, }, } @@ -29,12 +35,40 @@ type wrapperModifier struct { modifierBase[string] } -func (t *wrapperModifier) TransformToOnChain(offChainValue any, _ string) (any, error) { - return transformWithMaps(offChainValue, t.offToOnChainType, t.fields, unwrapFieldMapAction) +func (m *wrapperModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) { + offChainValue, itemType, err := m.modifierBase.selectType(offChainValue, m.offChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := transformWithMaps(offChainValue, m.offToOnChainType, m.fields, unwrapFieldMapAction) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } -func (t *wrapperModifier) TransformToOffChain(onChainValue any, _ string) (any, error) { - return transformWithMaps(onChainValue, t.onToOffChainType, t.fields, wrapFieldMapAction) +func (m *wrapperModifier) TransformToOffChain(onChainValue any, itemType string) (any, error) { + onChainValue, itemType, err := m.modifierBase.selectType(onChainValue, m.onChainStructType, itemType) + if err != nil { + return nil, err + } + + modified, err := transformWithMaps(onChainValue, m.onToOffChainType, m.fields, wrapFieldMapAction) + if err != nil { + return nil, err + } + + if itemType != "" { + return valueForPath(reflect.ValueOf(modified), itemType) + } + + return modified, nil } func wrapFieldMapAction(typesMap map[string]any, fieldName string, wrappedFieldName string) error { diff --git a/pkg/codec/wrapper_test.go b/pkg/codec/wrapper_test.go index 11bf148b6..ea2a9c8d3 100644 --- a/pkg/codec/wrapper_test.go +++ b/pkg/codec/wrapper_test.go @@ -2,6 +2,7 @@ package codec_test import ( "errors" + "log" "reflect" "testing" @@ -12,6 +13,29 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" ) +func ExampleNewWrapperModifier() { + type nestedOnChain struct { + C int + } + + type onChain struct { + A string + B nestedOnChain + } + + // specify the fields to be wrapped and the name of the new field with a map + fields := map[string]string{"A": "X", "B.C": "Y"} + wrapper := codec.NewWrapperModifier(fields) + + offChainType, _ := wrapper.RetypeToOffChain(reflect.TypeOf(onChain{}), "") + + // expected off-chain type: + // struct { A struct { X string }; B struct { C struct { Y int } } } + // + // both A and B.C were wrapped in a new struct with the respective specified field names + log.Println(offChainType) +} + func TestWrapper(t *testing.T) { t.Parallel() @@ -75,6 +99,8 @@ func TestWrapper(t *testing.T) { require.NoError(t, err) assert.Equal(t, 4, offChainType.NumField()) + t.Log(offChainType) + f0 := offChainType.Field(0) f0PreRetype := reflect.TypeOf(nestedTestStruct{}).Field(0) assert.Equal(t, wrapType("X", f0PreRetype.Type).String(), f0.Type.String())