diff --git a/matchers.go b/matchers.go index 223f6ef53..dc2706331 100644 --- a/matchers.go +++ b/matchers.go @@ -485,10 +485,15 @@ func Not(matcher types.GomegaMatcher) types.GomegaMatcher { } //WithTransform applies the `transform` to the actual value and matches it against `matcher`. -//The given transform must be a function of one parameter that returns one value. +//The given transform must be either a function of one parameter that returns one value or a +// function of one parameter that returns two values, where the second value must be of the +// error type. // var plus1 = func(i int) int { return i + 1 } // Expect(1).To(WithTransform(plus1, Equal(2)) // +// var failingplus1 = func(i int) (int, error) { return 42, "this does not compute" } +// Expect(1).To(WithTrafo(failingplus1, Equal(2))) +// //And(), Or(), Not() and WithTransform() allow matchers to be composed into complex expressions. func WithTransform(transform interface{}, matcher types.GomegaMatcher) types.GomegaMatcher { return matchers.NewWithTransformMatcher(transform, matcher) diff --git a/matchers/with_transform.go b/matchers/with_transform.go index 8a06bd384..07caa5bfa 100644 --- a/matchers/with_transform.go +++ b/matchers/with_transform.go @@ -9,7 +9,7 @@ import ( type WithTransformMatcher struct { // input - Transform interface{} // must be a function of one parameter that returns one value + Transform interface{} // must be a function of one parameter that returns one value and an optional error Matcher types.GomegaMatcher // cached value @@ -19,6 +19,9 @@ type WithTransformMatcher struct { transformedValue interface{} } +// reflect.Type for error +var errorT = reflect.TypeOf((*error)(nil)).Elem() + func NewWithTransformMatcher(transform interface{}, matcher types.GomegaMatcher) *WithTransformMatcher { if transform == nil { panic("transform function cannot be nil") @@ -27,8 +30,10 @@ func NewWithTransformMatcher(transform interface{}, matcher types.GomegaMatcher) if txType.NumIn() != 1 { panic("transform function must have 1 argument") } - if txType.NumOut() != 1 { - panic("transform function must have 1 return value") + if numout := txType.NumOut(); numout != 1 { + if numout != 2 || !txType.Out(1).AssignableTo(errorT) { + panic("transform function must either have 1 return value, or 1 return value plus 1 error value") + } } return &WithTransformMatcher{ @@ -57,6 +62,11 @@ func (m *WithTransformMatcher) Match(actual interface{}) (bool, error) { // call the Transform function with `actual` fn := reflect.ValueOf(m.Transform) result := fn.Call([]reflect.Value{param}) + if len(result) == 2 { + if !result[1].IsNil() { + return false, fmt.Errorf("Transform function failed: %e", result[1].Interface()) + } + } m.transformedValue = result[0].Interface() // expect exactly one value return m.Matcher.Match(m.transformedValue) diff --git a/matchers/with_transform_test.go b/matchers/with_transform_test.go index 38436fe75..570cd7c33 100644 --- a/matchers/with_transform_test.go +++ b/matchers/with_transform_test.go @@ -35,6 +35,11 @@ var _ = Describe("WithTransformMatcher", func() { panicsWithTransformer(func(i int) (int, int) { return 5, 6 }) }) }) + Context("Invalid number of return values, but correct number of arguments", func() { + It("Two return values, but second return value not an error", func() { + panicsWithTransformer(func(interface{}) (int, int) { return 5, 6 }) + }) + }) }) When("the actual value is incompatible", func() { @@ -121,6 +126,16 @@ var _ = Describe("WithTransformMatcher", func() { }) }) + When("transform fails", func() { + It("reports the transformation error", func() { + actual, trafo := "foo", func(string) (string, error) { return "", errors.New("that does not transform") } + success, err := WithTransform(trafo, Equal(actual)).Match(actual) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("that does not transform")) + }) + }) + Context("actual value is incompatible with transform function's argument type", func() { It("gracefully fails if transform cannot be performed", func() { m := WithTransform(plus1, Equal(3))