diff --git a/deep.go b/deep.go index ea6265c..4fc9566 100644 --- a/deep.go +++ b/deep.go @@ -41,13 +41,55 @@ var ( ErrNotHandled = errors.New("cannot compare the reflect.Kind") ) -type cmp struct { - diff []string - buff []string - floatFormat string +// Comparer is a struct capturing the configuration used for Equals(). The package +// Equal() function uses a default Comparer struct which references the global +// variables that control the execution of the equality algorithm. +type Comparer struct { + FloatPrecision *int + MaxDiff *int + MaxDepth *int + LogErrors *bool + CompareUnexportedFields *bool + ErrMaxRecursion *error + ErrTypeMismatch *error + ErrNotHandled *error } -var errorType = reflect.TypeOf((*error)(nil)).Elem() +// MakeComparer returns a Comparer struct with nil fields initialized to point +// to the global settings. +func MakeComparer(c Comparer) Comparer { + if c.FloatPrecision == nil { + c.FloatPrecision = &FloatPrecision + } + if c.MaxDiff == nil { + c.MaxDiff = &MaxDiff + } + if c.MaxDepth == nil { + c.MaxDepth = &MaxDepth + } + if c.LogErrors == nil { + c.LogErrors = &LogErrors + } + if c.CompareUnexportedFields == nil { + c.CompareUnexportedFields = &CompareUnexportedFields + } + if c.ErrMaxRecursion == nil { + c.ErrMaxRecursion = &ErrMaxRecursion + } + if c.ErrTypeMismatch == nil { + c.ErrTypeMismatch = &ErrTypeMismatch + } + if c.ErrNotHandled == nil { + c.ErrNotHandled = &ErrNotHandled + } + return c +} + +func makeDefaultComparer() Comparer { + return MakeComparer(Comparer{}) +} + +var DefaultComparer = makeDefaultComparer() // Equal compares variables a and b, recursing into their structure up to // MaxDepth levels deep, and returns a list of differences, or nil if there are @@ -56,12 +98,26 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem() // If a type has an Equal method, like time.Equal, it is called to check for // equality. func Equal(a, b interface{}) []string { + return DefaultComparer.Equal(a, b) +} + +type cmp struct { + diff []string + buff []string + floatFormat string + *Comparer +} + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +func (cp *Comparer) Equal(a, b interface{}) []string { aVal := reflect.ValueOf(a) bVal := reflect.ValueOf(b) c := &cmp{ diff: []string{}, buff: []string{}, - floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), + floatFormat: fmt.Sprintf("%%.%df", *cp.FloatPrecision), + Comparer: cp, } if a == nil && b == nil { return nil @@ -82,8 +138,8 @@ func Equal(a, b interface{}) []string { } func (c *cmp) equals(a, b reflect.Value, level int) { - if level > MaxDepth { - logError(ErrMaxRecursion) + if level > *c.MaxDepth { + c.logError(*c.ErrMaxRecursion) return } @@ -102,7 +158,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { bType := b.Type() if aType != bType { c.saveDiff(aType, bType) - logError(ErrTypeMismatch) + c.logError(*c.ErrTypeMismatch) return } @@ -181,7 +237,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } for i := 0; i < a.NumField(); i++ { - if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { + if aType.Field(i).PkgPath != "" && !*c.CompareUnexportedFields { continue // skip unexported field, e.g. s in type T struct {s string} } @@ -197,7 +253,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.pop() // pop field name from buff - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { break } } @@ -243,7 +299,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { return } } @@ -256,7 +312,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.push(fmt.Sprintf("map[%s]", key)) c.saveDiff("", b.MapIndex(key)) c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { return } } @@ -266,7 +322,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.push(fmt.Sprintf("array[%d]", i)) c.equals(a.Index(i), b.Index(i), level+1) c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { break } } @@ -300,7 +356,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.saveDiff("", b.Index(i)) } c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { break } } @@ -335,7 +391,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } default: - logError(ErrNotHandled) + c.logError(*c.ErrNotHandled) } } @@ -358,8 +414,8 @@ func (c *cmp) saveDiff(aval, bval interface{}) { } } -func logError(err error) { - if LogErrors { +func (c *cmp) logError(err error) { + if *c.LogErrors { log.Println(err) } }