Skip to content

Commit

Permalink
fix(scale): Use *int for scale index (#3274)
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn authored May 25, 2023
1 parent bd68814 commit 9b04d30
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 36 deletions.
58 changes: 35 additions & 23 deletions pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,18 @@ func Test_decodeState_decodeStruct(t *testing.T) {
if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}
var diff string
if tt.out != nil {
diff = cmp.Diff(dst, tt.out, cmpopts.IgnoreUnexported(tt.in))
} else {
diff = cmp.Diff(dst, tt.in, cmpopts.IgnoreUnexported(big.Int{}, tt.in, VDTValue2{}, MyStructWithIgnore{}))
}
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)

// assert response only if we aren't expecting an error
if !tt.wantErr {
var diff string
if tt.out != nil {
diff = cmp.Diff(dst, tt.out, cmpopts.IgnoreUnexported(tt.in))
} else {
diff = cmp.Diff(dst, tt.in, cmpopts.IgnoreUnexported(big.Int{}, tt.in, VDTValue2{}, MyStructWithIgnore{}))
}
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)
}
}
})
}
Expand Down Expand Up @@ -294,20 +298,24 @@ func Test_unmarshal_optionality(t *testing.T) {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
var diff string
if tt.out != nil {
diff = cmp.Diff(
reflect.ValueOf(dst).Elem().Interface(),
reflect.ValueOf(tt.out).Interface(),
cmpopts.IgnoreUnexported(tt.in))
} else {
diff = cmp.Diff(
reflect.ValueOf(dst).Elem().Interface(),
reflect.ValueOf(tt.in).Interface(),
cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{}))
}
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)

// assert response only if we aren't expecting an error
if !tt.wantErr {
var diff string
if tt.out != nil {
diff = cmp.Diff(
reflect.ValueOf(dst).Elem().Interface(),
reflect.ValueOf(tt.out).Interface(),
cmpopts.IgnoreUnexported(tt.in))
} else {
diff = cmp.Diff(
reflect.ValueOf(dst).Elem().Interface(),
reflect.ValueOf(tt.in).Interface(),
cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{}))
}
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)
}
}
}
})
Expand All @@ -325,7 +333,11 @@ func Test_unmarshal_optionality_nil_case(t *testing.T) {
// ignore out, since we are testing nil case
// out: t.out,
}
ptrTest.want = []byte{0x00}

// for error cases, we don't need to modify the input since we need it to fail
if !t.wantErr {
ptrTest.want = []byte{0x00}
}

temp := reflect.New(reflect.TypeOf(t.in))
// create a new pointer to type of temp
Expand Down
30 changes: 22 additions & 8 deletions pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,19 @@ var (
},
want: []byte{0x04, 0x01, 0x02, 0, 0, 0, 0x01},
},
{
name: "struct_{[]byte,_int32}_with_invalid_tag",
in: &struct {
Foo []byte `scale:"1,invalid"`
}{
Foo: []byte{0x01},
},
wantErr: true,
},
{
name: "struct_{[]byte,_int32,_bool}",
in: struct {
Baz bool `scale:"3,enum"`
Baz bool `scale:"3"`
Bar int32 `scale:"2"`
Foo []byte `scale:"1"`
}{
Expand Down Expand Up @@ -1073,8 +1082,12 @@ func Test_encodeState_encodeStruct(t *testing.T) {
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)

// we don't need this check for error cases
if !tt.wantErr {
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)
}
}
})
}
Expand Down Expand Up @@ -1182,8 +1195,12 @@ func Test_marshal_optionality(t *testing.T) {
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)

// if we expect an error, we do not need to check the result
if !tt.wantErr {
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
}
})
}
Expand All @@ -1195,9 +1212,6 @@ func Test_marshal_optionality_nil_cases(t *testing.T) {
t := allTests[i]
ptrTest := test{
name: t.name,
// in: t.in,
wantErr: t.wantErr,
want: t.want,
}
// create a new pointer to new zero value of t.in
temp := reflect.New(reflect.TypeOf(t.in))
Expand Down
1 change: 1 addition & 0 deletions pkg/scale/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ var (
errBigIntIsNil = errors.New("big int is nil")
ErrVaryingDataTypeNotSet = errors.New("varying data type not set")
ErrUnsupportedCustomPrimitive = errors.New("unsupported type for custom primitive")
ErrInvalidScaleIndex = errors.New("invalid scale index")
)
10 changes: 8 additions & 2 deletions pkg/scale/scale.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"sync"
)
Expand All @@ -19,7 +20,7 @@ var cache = &fieldScaleIndicesCache{
// fieldScaleIndex is used to map field index to scale index
type fieldScaleIndex struct {
fieldIndex int
scaleIndex *string
scaleIndex *int
}
type fieldScaleIndices []fieldScaleIndex

Expand Down Expand Up @@ -61,9 +62,14 @@ func (fsic *fieldScaleIndicesCache) fieldScaleIndices(in interface{}) (
// ignore this field
continue
default:
scaleIndex, indexErr := strconv.Atoi(tag)
if indexErr != nil {
err = fmt.Errorf("%w: %v", ErrInvalidScaleIndex, indexErr)
return
}
indices = append(indices, fieldScaleIndex{
fieldIndex: i,
scaleIndex: &tag,
scaleIndex: &scaleIndex,
})
}
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/scale/scale_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ func Test_fieldScaleIndicesCache_fieldScaleIndices(t *testing.T) {
wantIndices: fieldScaleIndices{
{
fieldIndex: 5,
scaleIndex: newStringPtr("1"),
scaleIndex: newIntPtr(1),
},
{
fieldIndex: 3,
scaleIndex: newStringPtr("2"),
scaleIndex: newIntPtr(2),
},
{
fieldIndex: 1,
scaleIndex: newStringPtr("3"),
scaleIndex: newIntPtr(3),
},
{
fieldIndex: 0,
Expand Down

0 comments on commit 9b04d30

Please sign in to comment.