Skip to content

Commit

Permalink
feat: add Compare using go-cmp (#546)
Browse files Browse the repository at this point in the history
* feat: add CompareEqual using go-cmp

* fix: use recover to let it as usual error handling

* refactor: rename Compare -> BeComparableTo
  • Loading branch information
xiantank authored Apr 27, 2022
1 parent 1c29028 commit e77ea75
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 2 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.18

require (
github.com/golang/protobuf v1.5.2
github.com/google/go-cmp v0.5.7
github.com/onsi/ginkgo/v2 v2.1.4
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4
gopkg.in/yaml.v2 v2.4.0
Expand All @@ -12,5 +13,6 @@ require (
require (
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.28.0 // indirect
)
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/onsi/ginkgo/v2 v2.1.4 h1:GNapqRSid3zijZ9H77KrgVG4/8KqiyRsxcSxe+7ApXY=
github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA=
Expand All @@ -11,8 +12,9 @@ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16C
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
Expand Down
10 changes: 10 additions & 0 deletions matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gomega
import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/onsi/gomega/matchers"
"github.com/onsi/gomega/types"
)
Expand All @@ -26,6 +27,15 @@ func BeEquivalentTo(expected interface{}) types.GomegaMatcher {
}
}

//BeComparableTo uses gocmp.Equal to compare. You can pass cmp.Option as options.
//It is an error for actual and expected to be nil. Use BeNil() instead.
func BeComparableTo(expected interface{}, opts ...cmp.Option) types.GomegaMatcher {
return &matchers.BeComparableToMatcher{
Expected: expected,
Options: opts,
}
}

//BeIdenticalTo uses the == operator to compare actual with expected.
//BeIdenticalTo is strict about types when performing comparisons.
//It is an error for both actual and expected to be nil. Use BeNil() instead.
Expand Down
48 changes: 48 additions & 0 deletions matchers/be_comparable_to_matcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package matchers

import (
"bytes"
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/onsi/gomega/format"
)

type BeComparableToMatcher struct {
Expected interface{}
Options cmp.Options
}

func (matcher *BeComparableToMatcher) Match(actual interface{}) (success bool, matchErr error) {
if actual == nil && matcher.Expected == nil {
return false, fmt.Errorf("Refusing to compare <nil> to <nil>.\nBe explicit and use BeNil() instead. This is to avoid mistakes where both sides of an assertion are erroneously uninitialized.")
}
// Shortcut for byte slices.
// Comparing long byte slices with reflect.DeepEqual is very slow,
// so use bytes.Equal if actual and expected are both byte slices.
if actualByteSlice, ok := actual.([]byte); ok {
if expectedByteSlice, ok := matcher.Expected.([]byte); ok {
return bytes.Equal(actualByteSlice, expectedByteSlice), nil
}
}

defer func() {
if r := recover(); r != nil {
success = false
if err, ok := r.(error); ok {
matchErr = err
} else if errMsg, ok := r.(string); ok {
matchErr = fmt.Errorf(errMsg)
}
}
}()

return cmp.Equal(actual, matcher.Expected, matcher.Options...), nil
}

func (matcher *BeComparableToMatcher) FailureMessage(actual interface{}) (message string) {
return cmp.Diff(matcher.Expected, actual)
}

func (matcher *BeComparableToMatcher) NegatedFailureMessage(actual interface{}) (message string) {
return format.Message(actual, "not to equal", matcher.Expected)
}
130 changes: 130 additions & 0 deletions matchers/be_comparable_to_matcher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package matchers_test

import (
"errors"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/matchers"
)

type wrapError struct {
msg string
err error
}

func (e wrapError) Error() string {
return e.msg
}

func (e wrapError) Unwrap() error {
return e.err
}

var _ = Describe("BeComparableTo", func() {
When("asserting that nil is comparable to nil", func() {
It("should error", func() {
success, err := (&BeComparableToMatcher{Expected: nil}).Match(nil)

Expect(success).Should(BeFalse())
Expect(err).Should(HaveOccurred())
})
})

Context("When asserting on nil", func() {
It("should do the right thing", func() {
Expect("foo").ShouldNot(BeComparableTo(nil))
Expect(nil).ShouldNot(BeComparableTo(3))
Expect([]int{1, 2}).ShouldNot(BeComparableTo(nil))
})
})

Context("When asserting time with different location ", func() {
var t1, t2, t3 time.Time

BeforeEach(func() {
t1 = time.Time{}
t2 = time.Time{}.Local()
t3 = t1.Add(time.Second)
})

It("should do the right thing", func() {
Expect(t1).Should(BeComparableTo(t2))
Expect(t1).ShouldNot(BeComparableTo(t3))
})
})

Context("When struct contain unexported fields", func() {
type structWithUnexportedFields struct {
unexported string
Exported string
}

var s1, s2 structWithUnexportedFields

BeforeEach(func() {
s1 = structWithUnexportedFields{unexported: "unexported", Exported: "Exported"}
s2 = structWithUnexportedFields{unexported: "unexported", Exported: "Exported"}
})

It("should get match err", func() {
success, err := (&BeComparableToMatcher{Expected: s1}).Match(s2)
Expect(success).Should(BeFalse())
Expect(err).Should(HaveOccurred())
})

It("should do the right thing", func() {
Expect(s1).Should(BeComparableTo(s2, cmpopts.IgnoreUnexported(structWithUnexportedFields{})))
})
})

Context("When compare error", func() {
var err1, err2 error

It("not equal", func() {
err1 = errors.New("error")
err2 = errors.New("error")
Expect(err1).ShouldNot(BeComparableTo(err2, cmpopts.EquateErrors()))
})

It("equal if err1 is err2", func() {
err1 = errors.New("error")
err2 = &wrapError{
msg: "some error",
err: err1,
}

Expect(err1).Should(BeComparableTo(err2, cmpopts.EquateErrors()))
})
})

Context("When asserting equal between objects", func() {
It("should do the right thing", func() {
Expect(5).Should(BeComparableTo(5))
Expect(5.0).Should(BeComparableTo(5.0))

Expect(5).ShouldNot(BeComparableTo("5"))
Expect(5).ShouldNot(BeComparableTo(5.0))
Expect(5).ShouldNot(BeComparableTo(3))

Expect("5").Should(BeComparableTo("5"))
Expect([]int{1, 2}).Should(BeComparableTo([]int{1, 2}))
Expect([]int{1, 2}).ShouldNot(BeComparableTo([]int{2, 1}))
Expect([]byte{'f', 'o', 'o'}).Should(BeComparableTo([]byte{'f', 'o', 'o'}))
Expect([]byte{'f', 'o', 'o'}).ShouldNot(BeComparableTo([]byte{'b', 'a', 'r'}))
Expect(map[string]string{"a": "b", "c": "d"}).Should(BeComparableTo(map[string]string{"a": "b", "c": "d"}))
Expect(map[string]string{"a": "b", "c": "d"}).ShouldNot(BeComparableTo(map[string]string{"a": "b", "c": "e"}))

Expect(myCustomType{s: "abc", n: 3, f: 2.0, arr: []string{"a", "b"}}).Should(BeComparableTo(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}, cmpopts.IgnoreUnexported(myCustomType{})))

Expect(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}).Should(BeComparableTo(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}, cmp.AllowUnexported(myCustomType{})))
Expect(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}).ShouldNot(BeComparableTo(myCustomType{s: "bar", n: 3, f: 2.0, arr: []string{"a", "b"}}, cmp.AllowUnexported(myCustomType{})))
Expect(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}).ShouldNot(BeComparableTo(myCustomType{s: "foo", n: 2, f: 2.0, arr: []string{"a", "b"}}, cmp.AllowUnexported(myCustomType{})))
Expect(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}).ShouldNot(BeComparableTo(myCustomType{s: "foo", n: 3, f: 3.0, arr: []string{"a", "b"}}, cmp.AllowUnexported(myCustomType{})))
Expect(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b"}}).ShouldNot(BeComparableTo(myCustomType{s: "foo", n: 3, f: 2.0, arr: []string{"a", "b", "c"}}, cmp.AllowUnexported(myCustomType{})))
})
})
})

0 comments on commit e77ea75

Please sign in to comment.