Skip to content

Commit

Permalink
CmpWalkStructAreEqual
Browse files Browse the repository at this point in the history
  • Loading branch information
m-szalik committed Jul 24, 2024
1 parent 70b918e commit fe14044
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
89 changes: 88 additions & 1 deletion reflect_structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,93 @@ import (
"strings"
)

type CmpError interface {
error
FieldPath() string
A() any
B() any
}

type cmpErrorImpl struct {
fieldPath string
msg string
a any
b any
}

func (c cmpErrorImpl) A() any {
return c.a
}

func (c cmpErrorImpl) B() any {
return c.b
}

func (c cmpErrorImpl) FieldPath() string {
return c.fieldPath
}

func (c cmpErrorImpl) Error() string {
return fmt.Sprintf("field: %s - %s", c.fieldPath, c.msg)
}

func cmpError(field string, a, b any, msg string) CmpError {
return &cmpErrorImpl{field, msg, a, b}
}

func CmpWalkStructAreEqual(a interface{}, b interface{}) CmpError {
valA := elemValue(a)
valB := elemValue(b)
return cmpValue(valA, valB, "")
}

func cmpValue(a reflect.Value, b reflect.Value, fieldPath string) (eError CmpError) {
defer func() {
if r := recover(); r != nil {
eError = cmpError(fieldPath, a, b, eError.Error())
}
}()
if a.Type() != b.Type() || a.Kind() != b.Kind() {
return cmpError(fieldPath, a, b, fmt.Sprintf("A and B must be of the same type but %T and %T had been given", a, b))
}
if !a.IsValid() || !b.IsValid() {
return cmpError(fieldPath, a, b, "not valid element")
}
switch a.Kind() {
case reflect.Pointer:
if a.IsNil() != b.IsNil() {
return cmpError(fieldPath, a, b, "A and B must be of the same but one is nil")
}
return cmpValue(a.Elem(), b.Elem(), fieldPath)
case reflect.Struct:
for i := 0; i < a.NumField(); i++ {
aFieldValue := a.Field(i)
bFieldValue := b.Field(i)
structKey := a.Type().Field(i).Name
if err := cmpValue(aFieldValue, bFieldValue, fmt.Sprintf("%s.%s", fieldPath, structKey)); err != nil {
return err
}
}
case reflect.Array, reflect.Slice:
if a.Len() != b.Len() {
return cmpError(fieldPath, a, b, "arrays are not the same size")
}
if a.IsNil() != b.IsNil() {
return cmpError(fieldPath, a, b, "A and B must be of the same but one is nil")
}
for i := 0; i < a.Len(); i++ {
if err := cmpValue(a.Index(i), b.Index(i), fmt.Sprintf("%s[%d]", fieldPath, i)); err != nil {
return err
}
}
default:
if a.Interface() != b.Interface() {
return cmpError(fieldPath, a.Interface(), b.Interface(), fmt.Sprintf("%v != %v", a.Interface(), b.Interface()))
}
}
return nil
}

// AcceptFunc - function used by CopyStruct function.
type AcceptFunc func(fieldPath string, srcValue reflect.Value) bool

Expand All @@ -17,7 +104,7 @@ func CopyStruct(src interface{}, dst interface{}, acceptFunc AcceptFunc) error {
}
source := elemValue(src)
destination := elemValue(dst)
if source.Type() != destination.Type() {
if source.Type() != destination.Type() || source.Kind() != destination.Kind() {
return fmt.Errorf("source and destination must be of the same type but %T and %T had been given", source, destination)
}
return copyValue(source, destination, "", acceptFunc)
Expand Down
33 changes: 33 additions & 0 deletions reflect_structs_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package goutils

import (
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"reflect"
"testing"
Expand Down Expand Up @@ -95,3 +97,34 @@ func TestCopyStruct(t *testing.T) {
})

}

func TestCmpWalkStructAreEqual(t *testing.T) {
type tStruc struct {
X int
Array []string
}
tests := []struct {
name string
a interface{}
b interface{}
noErrorExpected bool
}{
{name: "not equal numbers", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 2, Array: []string{}}},
{name: "not equal array len", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 12, Array: []string{"ax"}}},
{name: "not equal array nil", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 12, Array: nil}},
{name: "equal", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 12, Array: []string{}}, noErrorExpected: true},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("CmpWalkStructAreEqual-%s", tt.name), func(t *testing.T) {
err := CmpWalkStructAreEqual(tt.a, tt.b)
if tt.noErrorExpected {
assert.NoError(t, err)
} else {
assert.Error(t, err)
var cmpErr CmpError
errors.As(err, &cmpErr)
fmt.Printf(" CMP %s --> A=%v, B=%v\n", cmpErr.Error(), cmpErr.A(), cmpErr.B())
}
})
}
}

0 comments on commit fe14044

Please sign in to comment.