Skip to content

Commit

Permalink
Add DeepSSZEqual and DeepNotSSZEqual (#8421)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xKiwi authored Feb 9, 2021
1 parent 2f98e6a commit cd3851c
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 11 deletions.
8 changes: 4 additions & 4 deletions beacon-chain/blockchain/process_block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"time"

"github.com/pkg/errors"
"github.com/prysmaticlabs/eth2-types"
types "github.com/prysmaticlabs/eth2-types"
ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/beacon-chain/cache/depositcache"
"github.com/prysmaticlabs/prysm/beacon-chain/core/blocks"
Expand Down Expand Up @@ -880,11 +880,11 @@ func TestUpdateJustifiedInitSync(t *testing.T) {

require.NoError(t, service.updateJustifiedInitSync(ctx, newCp))

assert.DeepEqual(t, currentCp, service.prevJustifiedCheckpt, "Incorrect previous justified checkpoint")
assert.DeepEqual(t, newCp, service.CurrentJustifiedCheckpt(), "Incorrect current justified checkpoint in cache")
assert.DeepSSZEqual(t, currentCp, service.prevJustifiedCheckpt, "Incorrect previous justified checkpoint")
assert.DeepSSZEqual(t, newCp, service.CurrentJustifiedCheckpt(), "Incorrect current justified checkpoint in cache")
cp, err := service.beaconDB.JustifiedCheckpoint(ctx)
require.NoError(t, err)
assert.DeepEqual(t, newCp, cp, "Incorrect current justified checkpoint in db")
assert.DeepSSZEqual(t, newCp, cp, "Incorrect current justified checkpoint in db")
}

func TestHandleEpochBoundary_BadMetrics(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions beacon-chain/core/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"testing"

"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/eth2-types"
types "github.com/prysmaticlabs/eth2-types"
ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/beacon-chain/core/state"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
Expand Down Expand Up @@ -72,8 +72,8 @@ func TestGenesisBeaconState_OK(t *testing.T) {

// Recent state checks.
assert.DeepEqual(t, make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector), newState.Slashings(), "Slashings was not correctly initialized")
assert.DeepEqual(t, []*pb.PendingAttestation{}, newState.CurrentEpochAttestations(), "CurrentEpochAttestations was not correctly initialized")
assert.DeepEqual(t, []*pb.PendingAttestation{}, newState.PreviousEpochAttestations(), "PreviousEpochAttestations was not correctly initialized")
assert.DeepSSZEqual(t, []*pb.PendingAttestation{}, newState.CurrentEpochAttestations(), "CurrentEpochAttestations was not correctly initialized")
assert.DeepSSZEqual(t, []*pb.PendingAttestation{}, newState.PreviousEpochAttestations(), "PreviousEpochAttestations was not correctly initialized")

zeroHash := params.BeaconConfig().ZeroHash[:]
// History root checks.
Expand All @@ -82,7 +82,7 @@ func TestGenesisBeaconState_OK(t *testing.T) {

// Deposit root checks.
assert.DeepEqual(t, eth1Data.DepositRoot, newState.Eth1Data().DepositRoot, "Eth1Data DepositRoot was not correctly initialized")
assert.DeepEqual(t, []*ethpb.Eth1Data{}, newState.Eth1DataVotes(), "Eth1DataVotes was not correctly initialized")
assert.DeepSSZEqual(t, []*ethpb.Eth1Data{}, newState.Eth1DataVotes(), "Eth1DataVotes was not correctly initialized")
}

func TestGenesisState_HashEquality(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/operations/slashings/service_proposer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestPool_MarkIncludedProposerSlashing(t *testing.T) {
p.MarkIncludedProposerSlashing(tt.args.slashing)
assert.Equal(t, len(tt.want.pending), len(p.pendingProposerSlashing))
for i := range p.pendingProposerSlashing {
assert.DeepEqual(t, tt.want.pending[i], p.pendingProposerSlashing[i], "Unexpected pending proposer slashing at index %d", i)
assert.DeepSSZEqual(t, tt.want.pending[i], p.pendingProposerSlashing[i], "Unexpected pending proposer slashing at index %d", i)
}
assert.DeepEqual(t, tt.want.included, p.included)
})
Expand Down
10 changes: 10 additions & 0 deletions shared/testutil/assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ func DeepNotEqual(tb assertions.AssertionTestingTB, expected, actual interface{}
assertions.DeepNotEqual(tb.Errorf, expected, actual, msg...)
}

// DeepSSZEqual compares values using sszutil.DeepEqual.
func DeepSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepSSZEqual(tb.Errorf, expected, actual, msg...)
}

// DeepNotSSZEqual compares values using sszutil.DeepEqual.
func DeepNotSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepNotSSZEqual(tb.Errorf, expected, actual, msg...)
}

// NoError asserts that error is nil.
func NoError(tb assertions.AssertionTestingTB, err error, msg ...interface{}) {
assertions.NoError(tb.Errorf, err, msg...)
Expand Down
2 changes: 2 additions & 0 deletions shared/testutil/assertions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ go_library(
importpath = "github.com/prysmaticlabs/prysm/shared/testutil/assertions",
visibility = ["//visibility:public"],
deps = [
"//shared/sszutil:go_default_library",
"@com_github_d4l3k_messagediff//:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
],
)
Expand Down
33 changes: 31 additions & 2 deletions shared/testutil/assertions/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"strings"

"github.com/d4l3k/messagediff"
"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/prysm/shared/sszutil"
"github.com/sirupsen/logrus/hooks/test"
)

Expand Down Expand Up @@ -39,7 +41,7 @@ func NotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...i

// DeepEqual compares values using DeepEqual.
func DeepEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if !reflect.DeepEqual(expected, actual) {
if !isDeepEqual(expected, actual) {
errMsg := parseMsg("Values are not equal", msg...)
_, file, line, _ := runtime.Caller(2)
diff, _ := messagediff.PrettyDiff(expected, actual)
Expand All @@ -49,7 +51,26 @@ func DeepEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...

// DeepNotEqual compares values using DeepEqual.
func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if reflect.DeepEqual(expected, actual) {
if isDeepEqual(expected, actual) {
errMsg := parseMsg("Values are equal", msg...)
_, file, line, _ := runtime.Caller(2)
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
}
}

// DeepSSZEqual compares values using sszutil.DeepEqual.
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if !sszutil.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are not equal", msg...)
_, file, line, _ := runtime.Caller(2)
diff, _ := messagediff.PrettyDiff(expected, actual)
loggerFn("%s:%d %s, want: %#v, got: %#v, diff: %s", filepath.Base(file), line, errMsg, expected, actual, diff)
}
}

// DeepNotSSZEqual compares values using sszutil.DeepEqual.
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if sszutil.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are equal", msg...)
_, file, line, _ := runtime.Caller(2)
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
Expand Down Expand Up @@ -144,6 +165,14 @@ func parseMsg(defaultMsg string, msg ...interface{}) string {
return defaultMsg
}

func isDeepEqual(expected, actual interface{}) bool {
_, isProto := expected.(proto.Message)
if isProto {
return proto.Equal(expected.(proto.Message), actual.(proto.Message))
}
return reflect.DeepEqual(expected, actual)
}

// TBMock exposes enough testing.TB methods for assertions.
type TBMock struct {
ErrorfMsg string
Expand Down
124 changes: 124 additions & 0 deletions shared/testutil/assertions/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,130 @@ func TestAssert_DeepNotEqual(t *testing.T) {
}
}

func TestAssert_DeepSSZEqual(t *testing.T) {
type args struct {
tb *assertions.TBMock
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
expectedResult bool
}{
{
name: "equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{42},
},
expectedResult: true,
},
{
name: "equal structs",
args: args{
tb: &assertions.TBMock{},
expected: &eth.Checkpoint{
Epoch: 5,
Root: []byte("hi there"),
},
actual: &eth.Checkpoint{
Epoch: 5,
Root: []byte("hi there"),
},
},
expectedResult: true,
},
{
name: "non-equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{41},
},
expectedResult: false,
},
}
for _, tt := range tests {
verify := func() {
if tt.expectedResult && tt.args.tb.ErrorfMsg != "" {
t.Errorf("Unexpected error: %s %v", tt.name, tt.args.tb.ErrorfMsg)
}
}
t.Run(fmt.Sprintf("Assert/%s", tt.name), func(t *testing.T) {
assert.DeepSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
t.Run(fmt.Sprintf("Require/%s", tt.name), func(t *testing.T) {
require.DeepSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
}
}

func TestAssert_DeepNotSSZEqual(t *testing.T) {
type args struct {
tb *assertions.TBMock
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
expectedResult bool
}{
{
name: "equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{42},
},
expectedResult: true,
},
{
name: "non-equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{41},
},
expectedResult: false,
},
{
name: "not equal structs",
args: args{
tb: &assertions.TBMock{},
expected: &eth.Checkpoint{
Epoch: 5,
Root: []byte("hello there"),
},
actual: &eth.Checkpoint{
Epoch: 3,
Root: []byte("hi there"),
},
},
expectedResult: true,
},
}
for _, tt := range tests {
verify := func() {
if !tt.expectedResult && tt.args.tb.ErrorfMsg != "" {
t.Errorf("Unexpected error: %s %v", tt.name, tt.args.tb.ErrorfMsg)
}
}
t.Run(fmt.Sprintf("Assert/%s", tt.name), func(t *testing.T) {
assert.DeepNotSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
t.Run(fmt.Sprintf("Require/%s", tt.name), func(t *testing.T) {
require.DeepNotSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
}
}

func TestAssert_NoError(t *testing.T) {
type args struct {
tb *assertions.TBMock
Expand Down
10 changes: 10 additions & 0 deletions shared/testutil/require/requires.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ func DeepNotEqual(tb assertions.AssertionTestingTB, expected, actual interface{}
assertions.DeepNotEqual(tb.Fatalf, expected, actual, msg...)
}

// DeepSSZEqual compares values using DeepEqual.
func DeepSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepSSZEqual(tb.Fatalf, expected, actual, msg...)
}

// DeepNotSSZEqual compares values using DeepEqual.
func DeepNotSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepNotSSZEqual(tb.Fatalf, expected, actual, msg...)
}

// NoError asserts that error is nil.
func NoError(tb assertions.AssertionTestingTB, err error, msg ...interface{}) {
assertions.NoError(tb.Fatalf, err, msg...)
Expand Down

0 comments on commit cd3851c

Please sign in to comment.