diff --git a/common/types/BUILD.bazel b/common/types/BUILD.bazel index 32789f54..63842886 100644 --- a/common/types/BUILD.bazel +++ b/common/types/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "any_value.go", "bool.go", "bytes.go", + "compare.go", "double.go", "duration.go", "err.go", diff --git a/common/types/compare.go b/common/types/compare.go new file mode 100644 index 00000000..a6150ee7 --- /dev/null +++ b/common/types/compare.go @@ -0,0 +1,95 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "math" +) + +func compareDoubleInt(d Double, i Int) Int { + if d < math.MinInt64 { + return IntNegOne + } + if d > math.MaxInt64 { + return IntOne + } + return compareDouble(d, Double(i)) +} + +func compareIntDouble(i Int, d Double) Int { + return -compareDoubleInt(d, i) +} + +func compareDoubleUint(d Double, u Uint) Int { + if d < 0 { + return IntNegOne + } + if d > math.MaxUint64 { + return IntOne + } + return compareDouble(d, Double(u)) +} + +func compareUintDouble(u Uint, d Double) Int { + return -compareDoubleUint(d, u) +} + +func compareIntUint(i Int, u Uint) Int { + if i < 0 || u > math.MaxInt64 { + return IntNegOne + } + cmp := i - Int(u) + if cmp < 0 { + return IntNegOne + } + if cmp > 0 { + return IntOne + } + return IntZero +} + +func compareUintInt(u Uint, i Int) Int { + return -compareIntUint(i, u) +} + +func compareDouble(a, b Double) Int { + if a < b { + return IntNegOne + } + if a > b { + return IntOne + } + return IntZero +} + +func compareInt(a, b Int) Int { + if a < b { + return IntNegOne + } + if a > b { + return IntOne + } + return IntZero +} + +func compareUint(a, b Uint) Int { + if a < b { + return IntNegOne + } + if a > b { + return IntOne + } + return IntZero +} diff --git a/common/types/double.go b/common/types/double.go index 2a679137..ed094546 100644 --- a/common/types/double.go +++ b/common/types/double.go @@ -16,6 +16,7 @@ package types import ( "fmt" + "math" "reflect" "github.com/google/cel-go/common/types/ref" @@ -58,17 +59,22 @@ func (d Double) Add(other ref.Val) ref.Val { // Compare implements traits.Comparer.Compare. func (d Double) Compare(other ref.Val) ref.Val { - otherDouble, ok := other.(Double) - if !ok { - return MaybeNoSuchOverloadErr(other) + if math.IsNaN(float64(d)) { + return NewErr("NaN values cannot be ordered") } - if d < otherDouble { - return IntNegOne - } - if d > otherDouble { - return IntOne + switch ov := other.(type) { + case Double: + if math.IsNaN(float64(ov)) { + return NewErr("NaN values cannot be ordered") + } + return compareDouble(d, ov) + case Int: + return compareDoubleInt(d, ov) + case Uint: + return compareDoubleUint(d, ov) + default: + return MaybeNoSuchOverloadErr(other) } - return IntZero } // ConvertToNative implements ref.Val.ConvertToNative. @@ -158,12 +164,22 @@ func (d Double) Divide(other ref.Val) ref.Val { // Equal implements ref.Val.Equal. func (d Double) Equal(other ref.Val) ref.Val { - otherDouble, ok := other.(Double) - if !ok { + if math.IsNaN(float64(d)) { + return False + } + switch ov := other.(type) { + case Double: + if math.IsNaN(float64(ov)) { + return False + } + return Bool(d == ov) + case Int: + return Bool(compareDoubleInt(d, ov) == 0) + case Uint: + return Bool(compareDoubleUint(d, ov) == 0) + default: return MaybeNoSuchOverloadErr(other) } - // TODO: Handle NaNs properly. - return Bool(d == otherDouble) } // Multiply implements traits.Multiplier.Multiply. diff --git a/common/types/double_test.go b/common/types/double_test.go index 683690aa..d31cb2ef 100644 --- a/common/types/double_test.go +++ b/common/types/double_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "google.golang.org/protobuf/proto" anypb "google.golang.org/protobuf/types/known/anypb" @@ -39,19 +40,93 @@ func TestDoubleAdd(t *testing.T) { } func TestDoubleCompare(t *testing.T) { - lt := Double(-1300) - gt := Double(204) - if !lt.Compare(gt).Equal(IntNegOne).(Bool) { - t.Error("Comparison did not yield - 1") - } - if !gt.Compare(lt).Equal(IntOne).(Bool) { - t.Error("Comparison did not yield 1") - } - if !gt.Compare(gt).Equal(IntZero).(Bool) { - t.Error(("Comparison did not yield 0")) + tests := []struct { + a ref.Val + b ref.Val + out ref.Val + }{ + { + a: Double(42), + b: Double(42), + out: IntZero, + }, + { + a: Double(42), + b: Uint(42), + out: IntZero, + }, + { + a: Double(42), + b: Int(42), + out: IntZero, + }, + { + a: Double(-1300), + b: Double(204), + out: IntNegOne, + }, + { + a: Double(-1300), + b: Uint(204), + out: IntNegOne, + }, + { + a: Double(203.9), + b: Int(204), + out: IntNegOne, + }, + { + a: Double(1300), + b: Uint(math.MaxInt64) + 1, + out: IntNegOne, + }, + { + a: Double(204), + b: Uint(205), + out: IntNegOne, + }, + { + a: Double(204), + b: Double(math.MaxInt64) + 1025.0, + out: IntNegOne, + }, + { + a: Double(204), + b: Double(math.NaN()), + out: NewErr("NaN values cannot be ordered"), + }, + { + a: Double(math.NaN()), + b: Double(204), + out: NewErr("NaN values cannot be ordered"), + }, + { + a: Double(204), + b: Double(-1300), + out: IntOne, + }, + { + a: Double(204), + b: Uint(10), + out: IntOne, + }, + { + a: Double(204.1), + b: Int(204), + out: IntOne, + }, + { + a: Double(1), + b: String("1"), + out: NoSuchOverloadErr(), + }, } - if !IsError(gt.Compare(TypeType)) { - t.Error("Types not comparable") + for _, tc := range tests { + comparer := tc.a.(traits.Comparer) + got := comparer.Compare(tc.b) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("%v.Compare(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } } } @@ -291,8 +366,57 @@ func TestDoubleDivide(t *testing.T) { } func TestDoubleEqual(t *testing.T) { - if !IsError(Double(0).Equal(False)) { - t.Error("Double equal to non-double resulted in non-error.") + tests := []struct { + a ref.Val + b ref.Val + out ref.Val + }{ + { + a: Double(-10), + b: Double(-10), + out: True, + }, + { + a: Double(-10), + b: Double(10), + out: False, + }, + { + a: Double(10), + b: Uint(10), + out: True, + }, + { + a: Double(9), + b: Uint(10), + out: False, + }, + { + a: Double(10), + b: Int(10), + out: True, + }, + { + a: Double(10), + b: Int(-15), + out: False, + }, + { + a: Double(math.NaN()), + b: Int(10), + out: False, + }, + { + a: Double(10), + b: Unknown{2}, + out: Unknown{2}, + }, + } + for _, tc := range tests { + got := tc.a.Equal(tc.b) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("%v.Equal(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } } } diff --git a/common/types/int.go b/common/types/int.go index 1b45c642..8677c9c8 100644 --- a/common/types/int.go +++ b/common/types/int.go @@ -16,6 +16,7 @@ package types import ( "fmt" + "math" "reflect" "strconv" "time" @@ -72,17 +73,19 @@ func (i Int) Add(other ref.Val) ref.Val { // Compare implements traits.Comparer.Compare. func (i Int) Compare(other ref.Val) ref.Val { - otherInt, ok := other.(Int) - if !ok { + switch ov := other.(type) { + case Double: + if math.IsNaN(float64(ov)) { + return NewErr("NaN values cannot be ordered") + } + return compareIntDouble(i, ov) + case Int: + return compareInt(i, ov) + case Uint: + return compareIntUint(i, ov) + default: return MaybeNoSuchOverloadErr(other) } - if i < otherInt { - return IntNegOne - } - if i > otherInt { - return IntOne - } - return IntZero } // ConvertToNative implements ref.Val.ConvertToNative. @@ -208,11 +211,19 @@ func (i Int) Divide(other ref.Val) ref.Val { // Equal implements ref.Val.Equal. func (i Int) Equal(other ref.Val) ref.Val { - otherInt, ok := other.(Int) - if !ok { + switch ov := other.(type) { + case Double: + if math.IsNaN(float64(ov)) { + return False + } + return Bool(compareIntDouble(i, ov) == 0) + case Int: + return Bool(i == ov) + case Uint: + return Bool(compareIntUint(i, ov) == 0) + default: return MaybeNoSuchOverloadErr(other) } - return Bool(i == otherInt) } // Modulo implements traits.Modder.Modulo. diff --git a/common/types/int_test.go b/common/types/int_test.go index 07e46dde..cde9a1e7 100644 --- a/common/types/int_test.go +++ b/common/types/int_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "google.golang.org/protobuf/proto" anypb "google.golang.org/protobuf/types/known/anypb" @@ -52,19 +53,93 @@ func TestIntAdd(t *testing.T) { } func TestIntCompare(t *testing.T) { - lt := Int(-1300) - gt := Int(204) - if !lt.Compare(gt).Equal(IntNegOne).(Bool) { - t.Error("Comparison did not yield - 1") - } - if !gt.Compare(lt).Equal(IntOne).(Bool) { - t.Error("Comparison did not yield 1") - } - if !gt.Compare(gt).Equal(IntZero).(Bool) { - t.Error(("Comparison did not yield 0")) + tests := []struct { + a ref.Val + b ref.Val + out ref.Val + }{ + { + a: Int(42), + b: Int(42), + out: IntZero, + }, + { + a: Int(42), + b: Uint(42), + out: IntZero, + }, + { + a: Int(42), + b: Double(42), + out: IntZero, + }, + { + a: Int(-1300), + b: Int(204), + out: IntNegOne, + }, + { + a: Int(-1300), + b: Uint(204), + out: IntNegOne, + }, + { + a: Int(204), + b: Double(204.1), + out: IntNegOne, + }, + { + a: Int(1300), + b: Uint(math.MaxInt64) + 1, + out: IntNegOne, + }, + { + a: Int(204), + b: Uint(205), + out: IntNegOne, + }, + { + a: Int(204), + b: Double(math.MaxInt64) + 1025.0, + out: IntNegOne, + }, + { + a: Int(204), + b: Double(math.NaN()), + out: NewErr("NaN values cannot be ordered"), + }, + { + a: Int(204), + b: Int(-1300), + out: IntOne, + }, + { + a: Int(204), + b: Uint(10), + out: IntOne, + }, + { + a: Int(204), + b: Double(203.9), + out: IntOne, + }, + { + a: Int(204), + b: Double(math.MinInt64) - 1025.0, + out: IntOne, + }, + { + a: Int(1), + b: String("1"), + out: NoSuchOverloadErr(), + }, } - if !IsError(gt.Compare(TypeType)) { - t.Error("Got comparison value, expected error.") + for _, tc := range tests { + comparer := tc.a.(traits.Comparer) + got := comparer.Compare(tc.b) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("%v.Compare(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } } } @@ -283,8 +358,57 @@ func TestIntDivide(t *testing.T) { } func TestIntEqual(t *testing.T) { - if !IsError(Int(0).Equal(False)) { - t.Error("Int equal to non-int type resulted in non-error.") + tests := []struct { + a ref.Val + b ref.Val + out ref.Val + }{ + { + a: Int(-10), + b: Int(-10), + out: True, + }, + { + a: Int(-10), + b: Int(10), + out: False, + }, + { + a: Int(10), + b: Uint(10), + out: True, + }, + { + a: Int(9), + b: Uint(10), + out: False, + }, + { + a: Int(10), + b: Double(10), + out: True, + }, + { + a: Int(10), + b: Double(-10.5), + out: False, + }, + { + a: Int(10), + b: Double(math.NaN()), + out: False, + }, + { + a: Int(10), + b: Unknown{2}, + out: Unknown{2}, + }, + } + for _, tc := range tests { + got := tc.a.Equal(tc.b) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("%v.Equal(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } } } diff --git a/common/types/list_test.go b/common/types/list_test.go index 43c2cc27..30e0f52f 100644 --- a/common/types/list_test.go +++ b/common/types/list_test.go @@ -16,6 +16,7 @@ package types import ( "encoding/json" + "math" "reflect" "testing" @@ -48,28 +49,52 @@ func TestBaseListAdd_Error(t *testing.T) { func TestBaseListContains(t *testing.T) { list := NewDynamicList(newTestRegistry(t), []float32{1.0, 2.0, 3.0}) - if list.Contains(Double(5)) != False { - t.Error("List contains did not return false") - } - if list.Contains(Double(3)) != True { - t.Error("List contains did not succeed") - } - list = NewDynamicList(newTestRegistry(t), []interface{}{1.0, 2, 3.0}) - if list.Contains(Int(2)) != True { - t.Error("List contains did not succeed") - } - if list.Contains(Double(3)) != True { - t.Error("List contains did not succeed") - } -} - -func TestBaseListContains_NonBool(t *testing.T) { - list := NewDynamicList(newTestRegistry(t), []interface{}{1.0, 2, 3.0}) - if !IsError(list.Contains(Int(3))) { - t.Error("List contains succeeded with wrong type") - } - if !reflect.DeepEqual(list.Contains(Unknown{1}), Unknown{1}) { - t.Error("list.Contains(unknown) did not return unknown input") + tests := []struct { + in ref.Val + out ref.Val + }{ + { + in: Double(math.NaN()), + out: False, + }, + { + in: Double(5), + out: False, + }, + { + in: Double(3), + out: True, + }, + { + in: Uint(3), + out: True, + }, + { + in: Int(3), + out: True, + }, + { + in: Int(3), + out: True, + }, + { + in: Int(0), + out: False, + }, + { + in: String("3"), + out: NoSuchOverloadErr(), + }, + { + in: Unknown{1}, + out: Unknown{1}, + }, + } + for _, tc := range tests { + got := list.Contains(tc.in) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("list.Contains(%v) returned %v, wanted %v", tc.in, got, tc.out) + } } } @@ -276,7 +301,7 @@ func TestConcatListConvertToNative_Json(t *testing.T) { } } -func TestConcatListConvertToNative_ListInterface(t *testing.T) { +func TestConcatListConvertToNativeListInterface(t *testing.T) { reg := newTestRegistry(t) listA := NewDynamicList(reg, []float32{1.0, 2.0}) listB := NewStringList(reg, []string{"3.0"}) @@ -325,7 +350,7 @@ func TestConcatListContains(t *testing.T) { } } -func TestConcatListContains_NonBool(t *testing.T) { +func TestConcatListContainsNonBool(t *testing.T) { reg := newTestRegistry(t) listA := NewDynamicList(reg, []float32{1.0, 2.0}) listB := NewDynamicList(reg, []string{"3"}) @@ -353,8 +378,8 @@ func TestConcatListEqual(t *testing.T) { t.Errorf("list.Equal(listC) got %v, wanted false", list.Equal(listC)) } listD := reg.NativeToValue([]interface{}{1, 2.0, 3.0}) - if !IsError(list.Equal(listD)) { - t.Errorf("list.Equal(listD) got %v, wanted error", list.Equal(listD)) + if list.Equal(listD) != True { + t.Errorf("list.Equal(listD) got %v, wanted true", list.Equal(listD)) } } diff --git a/common/types/uint.go b/common/types/uint.go index d8d0f24c..c1c89e4b 100644 --- a/common/types/uint.go +++ b/common/types/uint.go @@ -16,6 +16,7 @@ package types import ( "fmt" + "math" "reflect" "strconv" @@ -65,17 +66,19 @@ func (i Uint) Add(other ref.Val) ref.Val { // Compare implements traits.Comparer.Compare. func (i Uint) Compare(other ref.Val) ref.Val { - otherUint, ok := other.(Uint) - if !ok { + switch ov := other.(type) { + case Double: + if math.IsNaN(float64(ov)) { + return NewErr("NaN values cannot be ordered") + } + return compareUintDouble(i, ov) + case Int: + return compareUintInt(i, ov) + case Uint: + return compareUint(i, ov) + default: return MaybeNoSuchOverloadErr(other) } - if i < otherUint { - return IntNegOne - } - if i > otherUint { - return IntOne - } - return IntZero } // ConvertToNative implements ref.Val.ConvertToNative. @@ -176,11 +179,19 @@ func (i Uint) Divide(other ref.Val) ref.Val { // Equal implements ref.Val.Equal. func (i Uint) Equal(other ref.Val) ref.Val { - otherUint, ok := other.(Uint) - if !ok { + switch ov := other.(type) { + case Double: + if math.IsNaN(float64(ov)) { + return False + } + return Bool(compareUintDouble(i, ov) == 0) + case Int: + return Bool(compareUintInt(i, ov) == 0) + case Uint: + return Bool(i == ov) + default: return MaybeNoSuchOverloadErr(other) } - return Bool(i == otherUint) } // Modulo implements traits.Modder.Modulo. diff --git a/common/types/uint_test.go b/common/types/uint_test.go index d6ec6033..29b59764 100644 --- a/common/types/uint_test.go +++ b/common/types/uint_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "google.golang.org/protobuf/proto" anypb "google.golang.org/protobuf/types/known/anypb" @@ -45,19 +46,88 @@ func TestUint_Add(t *testing.T) { } func TestUint_Compare(t *testing.T) { - lt := Uint(204) - gt := Uint(1300) - if !lt.Compare(gt).Equal(IntNegOne).(Bool) { - t.Error("Comparison did not yield - 1") - } - if !gt.Compare(lt).Equal(IntOne).(Bool) { - t.Error("Comparison did not yield 1") - } - if !gt.Compare(gt).Equal(IntZero).(Bool) { - t.Error(("Comparison did not yield 0")) + tests := []struct { + a ref.Val + b ref.Val + out ref.Val + }{ + { + a: Uint(42), + b: Uint(42), + out: IntZero, + }, + { + a: Uint(42), + b: Int(42), + out: IntZero, + }, + { + a: Uint(42), + b: Double(42), + out: IntZero, + }, + { + a: Uint(13), + b: Int(204), + out: IntNegOne, + }, + { + a: Uint(13), + b: Uint(204), + out: IntNegOne, + }, + { + a: Uint(204), + b: Double(204.1), + out: IntNegOne, + }, + { + a: Uint(204), + b: Int(205), + out: IntNegOne, + }, + { + a: Uint(204), + b: Double(math.MaxUint64) + 2049.0, + out: IntNegOne, + }, + { + a: Uint(204), + b: Double(math.NaN()), + out: NewErr("NaN values cannot be ordered"), + }, + { + a: Uint(1300), + b: Int(-1), + out: IntOne, + }, + { + a: Uint(204), + b: Uint(13), + out: IntOne, + }, + { + a: Uint(204), + b: Double(203.9), + out: IntOne, + }, + { + a: Uint(204), + b: Double(-1.0), + out: IntOne, + }, + { + a: Uint(1), + b: String("1"), + out: NoSuchOverloadErr(), + }, } - if !IsError(gt.Compare(TypeType)) { - t.Error("Types not comparable") + for _, tc := range tests { + comparer := tc.a.(traits.Comparer) + got := comparer.Compare(tc.b) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("%v.Compare(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } } } @@ -236,8 +306,57 @@ func TestUint_Divide(t *testing.T) { } func TestUint_Equal(t *testing.T) { - if !IsError(Uint(0).Equal(False)) { - t.Error("Uint equal to non-uint type result in non-error") + tests := []struct { + a ref.Val + b ref.Val + out ref.Val + }{ + { + a: Uint(10), + b: Uint(10), + out: True, + }, + { + a: Uint(10), + b: Int(-10), + out: False, + }, + { + a: Uint(10), + b: Int(10), + out: True, + }, + { + a: Uint(9), + b: Int(10), + out: False, + }, + { + a: Uint(10), + b: Double(10), + out: True, + }, + { + a: Uint(10), + b: Double(-10.5), + out: False, + }, + { + a: Uint(10), + b: Double(math.NaN()), + out: False, + }, + { + a: Uint(10), + b: Unknown{2}, + out: Unknown{2}, + }, + } + for _, tc := range tests { + got := tc.a.Equal(tc.b) + if !reflect.DeepEqual(got, tc.out) { + t.Errorf("%v.Equal(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } } } diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index b03dc6c3..8567e968 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -26,6 +26,8 @@ sh_test( args = [ "$(location @com_google_cel_spec//tests/simple:simple_test)", "--server=$(location //server/main:cel_server)", + # Tests that need to be removed as the spec has changed + "--skip_test=comparisons/eq_literal/eq_mixed_types_error,eq_list_elem_mixed_types_error;ne_literal/ne_mixed_types_error", # Failing conformance tests. "--skip_test=fields/qualified_identifier_resolution/map_key_float,map_key_null,map_value_repeat_key",