Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Equal() to be used by different packages in one program. #20

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 74 additions & 18 deletions deep.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,55 @@ var (
ErrNotHandled = errors.New("cannot compare the reflect.Kind")
)

type cmp struct {
diff []string
buff []string
floatFormat string
// Comparer is a struct capturing the configuration used for Equals(). The package
// Equal() function uses a default Comparer struct which references the global
// variables that control the execution of the equality algorithm.
type Comparer struct {
FloatPrecision *int
MaxDiff *int
MaxDepth *int
LogErrors *bool
CompareUnexportedFields *bool
ErrMaxRecursion *error
ErrTypeMismatch *error
ErrNotHandled *error
}

var errorType = reflect.TypeOf((*error)(nil)).Elem()
// MakeComparer returns a Comparer struct with nil fields initialized to point
// to the global settings.
func MakeComparer(c Comparer) Comparer {
if c.FloatPrecision == nil {
c.FloatPrecision = &FloatPrecision
}
if c.MaxDiff == nil {
c.MaxDiff = &MaxDiff
}
if c.MaxDepth == nil {
c.MaxDepth = &MaxDepth
}
if c.LogErrors == nil {
c.LogErrors = &LogErrors
}
if c.CompareUnexportedFields == nil {
c.CompareUnexportedFields = &CompareUnexportedFields
}
if c.ErrMaxRecursion == nil {
c.ErrMaxRecursion = &ErrMaxRecursion
}
if c.ErrTypeMismatch == nil {
c.ErrTypeMismatch = &ErrTypeMismatch
}
if c.ErrNotHandled == nil {
c.ErrNotHandled = &ErrNotHandled
}
return c
}

func makeDefaultComparer() Comparer {
return MakeComparer(Comparer{})
}

var DefaultComparer = makeDefaultComparer()

// Equal compares variables a and b, recursing into their structure up to
// MaxDepth levels deep, and returns a list of differences, or nil if there are
Expand All @@ -56,12 +98,26 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem()
// If a type has an Equal method, like time.Equal, it is called to check for
// equality.
func Equal(a, b interface{}) []string {
return DefaultComparer.Equal(a, b)
}

type cmp struct {
diff []string
buff []string
floatFormat string
*Comparer
}

var errorType = reflect.TypeOf((*error)(nil)).Elem()

func (cp *Comparer) Equal(a, b interface{}) []string {
aVal := reflect.ValueOf(a)
bVal := reflect.ValueOf(b)
c := &cmp{
diff: []string{},
buff: []string{},
floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
floatFormat: fmt.Sprintf("%%.%df", *cp.FloatPrecision),
Comparer: cp,
}
if a == nil && b == nil {
return nil
Expand All @@ -82,8 +138,8 @@ func Equal(a, b interface{}) []string {
}

func (c *cmp) equals(a, b reflect.Value, level int) {
if level > MaxDepth {
logError(ErrMaxRecursion)
if level > *c.MaxDepth {
c.logError(*c.ErrMaxRecursion)
return
}

Expand All @@ -102,7 +158,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
bType := b.Type()
if aType != bType {
c.saveDiff(aType, bType)
logError(ErrTypeMismatch)
c.logError(*c.ErrTypeMismatch)
return
}

Expand Down Expand Up @@ -181,7 +237,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
}

for i := 0; i < a.NumField(); i++ {
if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
if aType.Field(i).PkgPath != "" && !*c.CompareUnexportedFields {
continue // skip unexported field, e.g. s in type T struct {s string}
}

Expand All @@ -197,7 +253,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {

c.pop() // pop field name from buff

if len(c.diff) >= MaxDiff {
if len(c.diff) >= *c.MaxDiff {
break
}
}
Expand Down Expand Up @@ -243,7 +299,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {

c.pop()

if len(c.diff) >= MaxDiff {
if len(c.diff) >= *c.MaxDiff {
return
}
}
Expand All @@ -256,7 +312,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.push(fmt.Sprintf("map[%s]", key))
c.saveDiff("<does not have key>", b.MapIndex(key))
c.pop()
if len(c.diff) >= MaxDiff {
if len(c.diff) >= *c.MaxDiff {
return
}
}
Expand All @@ -266,7 +322,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.push(fmt.Sprintf("array[%d]", i))
c.equals(a.Index(i), b.Index(i), level+1)
c.pop()
if len(c.diff) >= MaxDiff {
if len(c.diff) >= *c.MaxDiff {
break
}
}
Expand Down Expand Up @@ -300,7 +356,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.saveDiff("<no value>", b.Index(i))
}
c.pop()
if len(c.diff) >= MaxDiff {
if len(c.diff) >= *c.MaxDiff {
break
}
}
Expand Down Expand Up @@ -335,7 +391,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
}

default:
logError(ErrNotHandled)
c.logError(*c.ErrNotHandled)
}
}

Expand All @@ -358,8 +414,8 @@ func (c *cmp) saveDiff(aval, bval interface{}) {
}
}

func logError(err error) {
if LogErrors {
func (c *cmp) logError(err error) {
if *c.LogErrors {
log.Println(err)
}
}