diff --git a/reflect_structs.go b/reflect_structs.go index c816b78..6c5e82b 100644 --- a/reflect_structs.go +++ b/reflect_structs.go @@ -7,6 +7,93 @@ import ( "strings" ) +type CmpError interface { + error + FieldPath() string + A() any + B() any +} + +type cmpErrorImpl struct { + fieldPath string + msg string + a any + b any +} + +func (c cmpErrorImpl) A() any { + return c.a +} + +func (c cmpErrorImpl) B() any { + return c.b +} + +func (c cmpErrorImpl) FieldPath() string { + return c.fieldPath +} + +func (c cmpErrorImpl) Error() string { + return fmt.Sprintf("field: %s - %s", c.fieldPath, c.msg) +} + +func cmpError(field string, a, b any, msg string) CmpError { + return &cmpErrorImpl{field, msg, a, b} +} + +func CmpWalkStructAreEqual(a interface{}, b interface{}) CmpError { + valA := elemValue(a) + valB := elemValue(b) + return cmpValue(valA, valB, "") +} + +func cmpValue(a reflect.Value, b reflect.Value, fieldPath string) (eError CmpError) { + defer func() { + if r := recover(); r != nil { + eError = cmpError(fieldPath, a, b, eError.Error()) + } + }() + if a.Type() != b.Type() || a.Kind() != b.Kind() { + return cmpError(fieldPath, a, b, fmt.Sprintf("A and B must be of the same type but %T and %T had been given", a, b)) + } + if !a.IsValid() || !b.IsValid() { + return cmpError(fieldPath, a, b, "not valid element") + } + switch a.Kind() { + case reflect.Pointer: + if a.IsNil() != b.IsNil() { + return cmpError(fieldPath, a, b, "A and B must be of the same but one is nil") + } + return cmpValue(a.Elem(), b.Elem(), fieldPath) + case reflect.Struct: + for i := 0; i < a.NumField(); i++ { + aFieldValue := a.Field(i) + bFieldValue := b.Field(i) + structKey := a.Type().Field(i).Name + if err := cmpValue(aFieldValue, bFieldValue, fmt.Sprintf("%s.%s", fieldPath, structKey)); err != nil { + return err + } + } + case reflect.Array, reflect.Slice: + if a.Len() != b.Len() { + return cmpError(fieldPath, a, b, "arrays are not the same size") + } + if a.IsNil() != b.IsNil() { + return cmpError(fieldPath, a, b, "A and B must be of the same but one is nil") + } + for i := 0; i < a.Len(); i++ { + if err := cmpValue(a.Index(i), b.Index(i), fmt.Sprintf("%s[%d]", fieldPath, i)); err != nil { + return err + } + } + default: + if a.Interface() != b.Interface() { + return cmpError(fieldPath, a.Interface(), b.Interface(), fmt.Sprintf("%v != %v", a.Interface(), b.Interface())) + } + } + return nil +} + // AcceptFunc - function used by CopyStruct function. type AcceptFunc func(fieldPath string, srcValue reflect.Value) bool @@ -17,7 +104,7 @@ func CopyStruct(src interface{}, dst interface{}, acceptFunc AcceptFunc) error { } source := elemValue(src) destination := elemValue(dst) - if source.Type() != destination.Type() { + if source.Type() != destination.Type() || source.Kind() != destination.Kind() { return fmt.Errorf("source and destination must be of the same type but %T and %T had been given", source, destination) } return copyValue(source, destination, "", acceptFunc) diff --git a/reflect_structs_test.go b/reflect_structs_test.go index 1602098..028ece8 100644 --- a/reflect_structs_test.go +++ b/reflect_structs_test.go @@ -1,6 +1,8 @@ package goutils import ( + "errors" + "fmt" "github.com/stretchr/testify/assert" "reflect" "testing" @@ -95,3 +97,34 @@ func TestCopyStruct(t *testing.T) { }) } + +func TestCmpWalkStructAreEqual(t *testing.T) { + type tStruc struct { + X int + Array []string + } + tests := []struct { + name string + a interface{} + b interface{} + noErrorExpected bool + }{ + {name: "not equal numbers", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 2, Array: []string{}}}, + {name: "not equal array len", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 12, Array: []string{"ax"}}}, + {name: "not equal array nil", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 12, Array: nil}}, + {name: "equal", a: tStruc{X: 12, Array: []string{}}, b: tStruc{X: 12, Array: []string{}}, noErrorExpected: true}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("CmpWalkStructAreEqual-%s", tt.name), func(t *testing.T) { + err := CmpWalkStructAreEqual(tt.a, tt.b) + if tt.noErrorExpected { + assert.NoError(t, err) + } else { + assert.Error(t, err) + var cmpErr CmpError + errors.As(err, &cmpErr) + fmt.Printf(" CMP %s --> A=%v, B=%v\n", cmpErr.Error(), cmpErr.A(), cmpErr.B()) + } + }) + } +}