Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Heterogeneous equality #482

Merged
merged 2 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions common/types/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,7 @@ func (b Bool) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements the ref.Val interface method.
func (b Bool) Equal(other ref.Val) ref.Val {
otherBool, ok := other.(Bool)
if !ok {
return ValOrErr(other, "no such overload")
}
return Bool(b == otherBool)
return Bool(ok && b == otherBool)
}

// Negate implements the traits.Negater interface method.
Expand Down
4 changes: 2 additions & 2 deletions common/types/bool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ func TestBoolEqual(t *testing.T) {
if False.Equal(True).(Bool) {
t.Error("False was equal to true")
}
if !IsError(Double(0.0).Equal(False)) {
t.Error("Cross-type equality yielded non-error value.")
if Double(0.0).Equal(False) != False {
t.Error("Cross-type equality yielded error value.")
}
}

Expand Down
5 changes: 1 addition & 4 deletions common/types/bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ func (b Bytes) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements the ref.Val interface method.
func (b Bytes) Equal(other ref.Val) ref.Val {
otherBytes, ok := other.(Bytes)
if !ok {
return ValOrErr(other, "no such overload")
}
return Bool(bytes.Equal(b, otherBytes))
return Bool(ok && bytes.Equal(b, otherBytes))
}

// Size implements the traits.Sizer interface method.
Expand Down
2 changes: 1 addition & 1 deletion common/types/double.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (d Double) Equal(other ref.Val) ref.Val {
case Uint:
return Bool(compareDoubleUint(d, ov) == 0)
default:
return MaybeNoSuchOverloadErr(other)
return False
}
}

Expand Down
5 changes: 0 additions & 5 deletions common/types/double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,6 @@ func TestDoubleEqual(t *testing.T) {
b: Int(10),
out: False,
},
{
a: Double(10),
b: Unknown{2},
out: Unknown{2},
},
}
for _, tc := range tests {
got := tc.a.Equal(tc.b)
Expand Down
5 changes: 1 addition & 4 deletions common/types/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ func (d Duration) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (d Duration) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(other)
}
return Bool(d.Duration == otherDur.Duration)
return Bool(ok && d.Duration == otherDur.Duration)
}

// Negate implements traits.Negater.Negate.
Expand Down
2 changes: 1 addition & 1 deletion common/types/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (i Int) Equal(other ref.Val) ref.Val {
case Uint:
return Bool(compareIntUint(i, ov) == 0)
default:
return MaybeNoSuchOverloadErr(other)
return False
}
}

Expand Down
5 changes: 0 additions & 5 deletions common/types/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,6 @@ func TestIntEqual(t *testing.T) {
b: Double(math.NaN()),
out: False,
},
{
a: Int(10),
b: Unknown{2},
out: Unknown{2},
},
}
for _, tc := range tests {
got := tc.a.Equal(tc.b)
Expand Down
8 changes: 4 additions & 4 deletions common/types/json_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func TestJsonListValueContains_MixedElemType(t *testing.T) {
// each element in the list. When the value is present, the result
// can be True. When the value is not present and the list is of
// mixed element type, the result is an error.
if !IsError(list.Contains(Double(2))) {
t.Error("Expected value list to not contain number '2' and error", list)
if list.Contains(Double(2)).(Bool) {
t.Error("Expected value list to not contain number '2'", list)
}
}

Expand Down Expand Up @@ -185,8 +185,8 @@ func TestJsonListValueEqual(t *testing.T) {
if listA.Add(listA).Equal(listB).(Bool) {
t.Error("Lists of different size were equal.")
}
if !IsError(listA.Equal(True)) {
t.Error("Equality of different type returned non-error.")
if IsError(listA.Equal(True)) {
t.Error("Equality of different type returned error.")
}
}

Expand Down
8 changes: 4 additions & 4 deletions common/types/json_struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ func TestJsonStructEqual(t *testing.T) {
if mapVal.Equal(NewJSONStruct(reg, &structpb.Struct{})) != False {
t.Error("Map with key-value pairs was equal to empty map")
}
if !IsError(mapVal.Equal(String(""))) {
t.Error("Map equal to a non-map type returned non-error.")
if IsError(mapVal.Equal(String(""))) {
t.Error("Map equal to a non-map type returned error, wanted 'false'")
}

other := NewJSONStruct(reg,
Expand All @@ -145,8 +145,8 @@ func TestJsonStructEqual(t *testing.T) {
map[int]interface{}{
1: "hello",
2: "world"})
if !IsError(mapVal.Equal(mismatch)) {
t.Error("Key type mismatch did not result in error")
if IsError(mapVal.Equal(mismatch)) {
t.Error("Key type mismatch resulted in error, wanted 'false'")
}
}

Expand Down
27 changes: 3 additions & 24 deletions common/types/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,14 @@ func (l *baseList) Add(other ref.Val) ref.Val {

// Contains implements the traits.Container interface method.
func (l *baseList) Contains(elem ref.Val) ref.Val {
if IsUnknownOrError(elem) {
return elem
}
var err ref.Val
for i := 0; i < l.size; i++ {
val := l.NativeToValue(l.get(i))
cmp := elem.Equal(val)
b, ok := cmp.(Bool)
// When there is an error on the contain check, this is not necessarily terminal.
// The contains call could find the element and return True, just as though the user
// had written a per-element comparison in an exists() macro or logical ||, e.g.
// list.exists(e, e == elem)
if !ok && err == nil {
err = ValOrErr(cmp, "no such overload")
}
if b == True {
if ok && b == True {
return True
}
}
if err != nil {
return err
}
return False
}

Expand Down Expand Up @@ -222,25 +208,18 @@ func (l *baseList) ConvertToType(typeVal ref.Type) ref.Val {
func (l *baseList) Equal(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if l.Size() != otherList.Size() {
return False
}
var maybeErr ref.Val
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := Equal(thisElem, otherElem)
if elemEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(elemEq) {
maybeErr = elemEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
Expand Down Expand Up @@ -349,7 +328,7 @@ func (l *concatList) ConvertToType(typeVal ref.Type) ref.Val {
func (l *concatList) Equal(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if l.Size() != otherList.Size() {
return False
Expand Down
17 changes: 8 additions & 9 deletions common/types/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ func TestBaseListContains(t *testing.T) {
},
{
in: String("3"),
out: NoSuchOverloadErr(),
},
{
in: Unknown{1},
out: Unknown{1},
out: False,
},
}
for _, tc := range tests {
Expand Down Expand Up @@ -174,8 +170,8 @@ func TestBaseListEqual(t *testing.T) {
if listA.Equal(listD) != False {
t.Error("listA.Equal(listD) did not return true")
}
if !IsError(listB.Equal(listD)) {
t.Error("listA.Equal(listD) did not error on single element type difference")
if IsError(listB.Equal(listD)) {
t.Error("listA.Equal(listD) errored, wanted 'false'")
}
}

Expand Down Expand Up @@ -355,8 +351,8 @@ func TestConcatListContainsNonBool(t *testing.T) {
listA := NewDynamicList(reg, []float32{1.0, 2.0})
listB := NewDynamicList(reg, []string{"3"})
listConcat := listA.Add(listB).(traits.Lister)
if !IsError(listConcat.Contains(String("4"))) {
t.Error("Contains did not error with list of mixed types an not found input.")
if IsError(listConcat.Contains(String("4"))) {
t.Error("Contains errored with a not-found element, wanted 'false'")
}
}

Expand All @@ -381,6 +377,9 @@ func TestConcatListEqual(t *testing.T) {
if list.Equal(listD) != True {
t.Errorf("list.Equal(listD) got %v, wanted true", list.Equal(listD))
}
if list.Equal(NullValue) != False {
t.Errorf("list.Equal(NullValue) got %v, wanted false", list.Equal(NullValue))
}
}

func TestConcatListGet(t *testing.T) {
Expand Down
Loading