Skip to content

Commit

Permalink
Numeric in/equality and comparisons across numeric types (#473)
Browse files Browse the repository at this point in the history
* Numeric in/equality runtime changes to support comparisons of numbers across types
* Fix for nan check
  • Loading branch information
TristonianJones authored Dec 11, 2021
1 parent 94e74ac commit fa20ec8
Show file tree
Hide file tree
Showing 10 changed files with 633 additions and 105 deletions.
1 change: 1 addition & 0 deletions common/types/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ go_library(
"any_value.go",
"bool.go",
"bytes.go",
"compare.go",
"double.go",
"duration.go",
"err.go",
Expand Down
95 changes: 95 additions & 0 deletions common/types/compare.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package types

import (
"math"
)

func compareDoubleInt(d Double, i Int) Int {
if d < math.MinInt64 {
return IntNegOne
}
if d > math.MaxInt64 {
return IntOne
}
return compareDouble(d, Double(i))
}

func compareIntDouble(i Int, d Double) Int {
return -compareDoubleInt(d, i)
}

func compareDoubleUint(d Double, u Uint) Int {
if d < 0 {
return IntNegOne
}
if d > math.MaxUint64 {
return IntOne
}
return compareDouble(d, Double(u))
}

func compareUintDouble(u Uint, d Double) Int {
return -compareDoubleUint(d, u)
}

func compareIntUint(i Int, u Uint) Int {
if i < 0 || u > math.MaxInt64 {
return IntNegOne
}
cmp := i - Int(u)
if cmp < 0 {
return IntNegOne
}
if cmp > 0 {
return IntOne
}
return IntZero
}

func compareUintInt(u Uint, i Int) Int {
return -compareIntUint(i, u)
}

func compareDouble(a, b Double) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}

func compareInt(a, b Int) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}

func compareUint(a, b Uint) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}
42 changes: 29 additions & 13 deletions common/types/double.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package types

import (
"fmt"
"math"
"reflect"

"github.com/google/cel-go/common/types/ref"
Expand Down Expand Up @@ -58,17 +59,22 @@ func (d Double) Add(other ref.Val) ref.Val {

// Compare implements traits.Comparer.Compare.
func (d Double) Compare(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return MaybeNoSuchOverloadErr(other)
if math.IsNaN(float64(d)) {
return NewErr("NaN values cannot be ordered")
}
if d < otherDouble {
return IntNegOne
}
if d > otherDouble {
return IntOne
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareDouble(d, ov)
case Int:
return compareDoubleInt(d, ov)
case Uint:
return compareDoubleUint(d, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
return IntZero
}

// ConvertToNative implements ref.Val.ConvertToNative.
Expand Down Expand Up @@ -158,12 +164,22 @@ func (d Double) Divide(other ref.Val) ref.Val {

// Equal implements ref.Val.Equal.
func (d Double) Equal(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
if math.IsNaN(float64(d)) {
return False
}
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(d == ov)
case Int:
return Bool(compareDoubleInt(d, ov) == 0)
case Uint:
return Bool(compareDoubleUint(d, ov) == 0)
default:
return MaybeNoSuchOverloadErr(other)
}
// TODO: Handle NaNs properly.
return Bool(d == otherDouble)
}

// Multiply implements traits.Multiplier.Multiply.
Expand Down
152 changes: 138 additions & 14 deletions common/types/double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"

anypb "google.golang.org/protobuf/types/known/anypb"
Expand All @@ -39,19 +40,93 @@ func TestDoubleAdd(t *testing.T) {
}

func TestDoubleCompare(t *testing.T) {
lt := Double(-1300)
gt := Double(204)
if !lt.Compare(gt).Equal(IntNegOne).(Bool) {
t.Error("Comparison did not yield - 1")
}
if !gt.Compare(lt).Equal(IntOne).(Bool) {
t.Error("Comparison did not yield 1")
}
if !gt.Compare(gt).Equal(IntZero).(Bool) {
t.Error(("Comparison did not yield 0"))
tests := []struct {
a ref.Val
b ref.Val
out ref.Val
}{
{
a: Double(42),
b: Double(42),
out: IntZero,
},
{
a: Double(42),
b: Uint(42),
out: IntZero,
},
{
a: Double(42),
b: Int(42),
out: IntZero,
},
{
a: Double(-1300),
b: Double(204),
out: IntNegOne,
},
{
a: Double(-1300),
b: Uint(204),
out: IntNegOne,
},
{
a: Double(203.9),
b: Int(204),
out: IntNegOne,
},
{
a: Double(1300),
b: Uint(math.MaxInt64) + 1,
out: IntNegOne,
},
{
a: Double(204),
b: Uint(205),
out: IntNegOne,
},
{
a: Double(204),
b: Double(math.MaxInt64) + 1025.0,
out: IntNegOne,
},
{
a: Double(204),
b: Double(math.NaN()),
out: NewErr("NaN values cannot be ordered"),
},
{
a: Double(math.NaN()),
b: Double(204),
out: NewErr("NaN values cannot be ordered"),
},
{
a: Double(204),
b: Double(-1300),
out: IntOne,
},
{
a: Double(204),
b: Uint(10),
out: IntOne,
},
{
a: Double(204.1),
b: Int(204),
out: IntOne,
},
{
a: Double(1),
b: String("1"),
out: NoSuchOverloadErr(),
},
}
if !IsError(gt.Compare(TypeType)) {
t.Error("Types not comparable")
for _, tc := range tests {
comparer := tc.a.(traits.Comparer)
got := comparer.Compare(tc.b)
if !reflect.DeepEqual(got, tc.out) {
t.Errorf("%v.Compare(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out)
}
}
}

Expand Down Expand Up @@ -291,8 +366,57 @@ func TestDoubleDivide(t *testing.T) {
}

func TestDoubleEqual(t *testing.T) {
if !IsError(Double(0).Equal(False)) {
t.Error("Double equal to non-double resulted in non-error.")
tests := []struct {
a ref.Val
b ref.Val
out ref.Val
}{
{
a: Double(-10),
b: Double(-10),
out: True,
},
{
a: Double(-10),
b: Double(10),
out: False,
},
{
a: Double(10),
b: Uint(10),
out: True,
},
{
a: Double(9),
b: Uint(10),
out: False,
},
{
a: Double(10),
b: Int(10),
out: True,
},
{
a: Double(10),
b: Int(-15),
out: False,
},
{
a: Double(math.NaN()),
b: Int(10),
out: False,
},
{
a: Double(10),
b: Unknown{2},
out: Unknown{2},
},
}
for _, tc := range tests {
got := tc.a.Equal(tc.b)
if !reflect.DeepEqual(got, tc.out) {
t.Errorf("%v.Equal(%v) got %v, wanted %v", tc.a, tc.b, got, tc.out)
}
}
}

Expand Down
Loading

0 comments on commit fa20ec8

Please sign in to comment.