diff --git a/LICENSE b/LICENSE index d6456956..2493ed2e 100644 --- a/LICENSE +++ b/LICENSE @@ -200,3 +200,34 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +=========================================================================== +The common/types/pb/equal.go modification of proto.Equal logic +=========================================================================== +Copyright (c) 2018 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/WORKSPACE b/WORKSPACE index 203840e3..4238e7ad 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -109,7 +109,7 @@ go_repository( # CEL Spec deps go_repository( name = "com_google_cel_spec", - commit = "1a75e8893bb2a1b2f7a63a32a76eda50294837b2", + commit = "b11d0c7144434ceec0fa602ad8391c08f4a591a9", importpath = "github.com/google/cel-spec", ) diff --git a/common/types/object.go b/common/types/object.go index d4efa066..8bc13e82 100644 --- a/common/types/object.go +++ b/common/types/object.go @@ -109,10 +109,7 @@ func (o *protoObj) ConvertToType(typeVal ref.Type) ref.Val { } func (o *protoObj) Equal(other ref.Val) ref.Val { - if o.typeDesc.Name() != other.Type().TypeName() { - return MaybeNoSuchOverloadErr(other) - } - return Bool(proto.Equal(o.value, other.Value().(proto.Message))) + return Bool(pb.Equal(o.value, other.Value().(proto.Message))) } // IsSet tests whether a field which is defined is set to a non-default value. diff --git a/common/types/pb/BUILD.bazel b/common/types/pb/BUILD.bazel index b0c79a7f..f23ac9c0 100644 --- a/common/types/pb/BUILD.bazel +++ b/common/types/pb/BUILD.bazel @@ -10,6 +10,7 @@ go_library( srcs = [ "checked.go", "enum.go", + "equal.go", "file.go", "pb.go", "type.go", @@ -17,6 +18,7 @@ go_library( importpath = "github.com/google/cel-go/common/types/pb", deps = [ "@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/protowire:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", @@ -34,6 +36,7 @@ go_test( name = "go_default_test", size = "small", srcs = [ + "equal_test.go", "file_test.go", "pb_test.go", "type_test.go", diff --git a/common/types/pb/equal.go b/common/types/pb/equal.go new file mode 100644 index 00000000..392a9496 --- /dev/null +++ b/common/types/pb/equal.go @@ -0,0 +1,205 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pb + +import ( + "bytes" + "reflect" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" +) + +// Equal returns whether two proto.Message instances are equal using the following criteria: +// +// - Messages must share the same instance of the type descriptor +// - Known set fields are compared using semantics equality +// - Bytes are compared using bytes.Equal +// - Scalar values are compared with operator == +// - List and map types are equal if they have the same length and all elements are equal +// - Messages are equal if they share the same descriptor and all set fields are equal +// - Unknown fields are compared using byte equality +// - NaN values are not equal to each other +// - google.protobuf.Any values are unpacked before comparison +// - If the type descriptor for a protobuf.Any cannot be found, byte equality is used rather than +// semantic equality. +// +// This method of proto equality mirrors the behavior of the C++ protobuf MessageDifferencer +// whereas the golang proto.Equal implementation mirrors the Java protobuf equals() methods +// behaviors which needed to treat NaN values as equal due to Java semantics. +func Equal(x, y proto.Message) bool { + if x == nil || y == nil { + return x == nil && y == nil + } + xRef := x.ProtoReflect() + yRef := y.ProtoReflect() + return equalMessage(xRef, yRef) +} + +func equalMessage(mx, my protoreflect.Message) bool { + // Note, the original proto.Equal upon which this implementation is based does not specifically handle the + // case when both messages are invalid. It is assumed that the descriptors will be equal and that byte-wise + // comparison will be used, though the semantics of validity are neither clear, nor promised within the + // proto.Equal implementation. + if mx.IsValid() != my.IsValid() || mx.Descriptor() != my.Descriptor() { + return false + } + + // This is an innovation on the default proto.Equal where protobuf.Any values are unpacked before comparison + // as otherwise the Any values are compared by bytes rather than structurally. + if isAny(mx) && isAny(my) { + ax := mx.Interface().(*anypb.Any) + ay := my.Interface().(*anypb.Any) + // If the values are not the same type url, return false. + if ax.GetTypeUrl() != ay.GetTypeUrl() { + return false + } + // If the values are byte equal, then return true. + if bytes.Equal(ax.GetValue(), ay.GetValue()) { + return true + } + // Otherwise fall through to the semantic comparison of the any values. + x, err := ax.UnmarshalNew() + if err != nil { + return false + } + y, err := ay.UnmarshalNew() + if err != nil { + return false + } + // Recursively compare the unwrapped messages to ensure nested Any values are unwrapped accordingly. + return equalMessage(x.ProtoReflect(), y.ProtoReflect()) + } + + // Walk the set fields to determine field-wise equality + nx := 0 + equal := true + mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { + nx++ + equal = my.Has(fd) && equalField(fd, vx, my.Get(fd)) + return equal + }) + if !equal { + return false + } + // Establish the count of set fields on message y + ny := 0 + my.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool { + ny++ + return true + }) + // If the number of set fields is not equal return false. + if nx != ny { + return false + } + + return equalUnknown(mx.GetUnknown(), my.GetUnknown()) +} + +func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { + switch { + case fd.IsMap(): + return equalMap(fd, x.Map(), y.Map()) + case fd.IsList(): + return equalList(fd, x.List(), y.List()) + default: + return equalValue(fd, x, y) + } +} + +func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool { + if x.Len() != y.Len() { + return false + } + equal := true + x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool { + vy := y.Get(k) + equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) + return equal + }) + return equal +} + +func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool { + if x.Len() != y.Len() { + return false + } + for i := x.Len() - 1; i >= 0; i-- { + if !equalValue(fd, x.Get(i), y.Get(i)) { + return false + } + } + return true +} + +func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { + switch fd.Kind() { + case protoreflect.BoolKind: + return x.Bool() == y.Bool() + case protoreflect.EnumKind: + return x.Enum() == y.Enum() + case protoreflect.Int32Kind, protoreflect.Sint32Kind, + protoreflect.Int64Kind, protoreflect.Sint64Kind, + protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind: + return x.Int() == y.Int() + case protoreflect.Uint32Kind, protoreflect.Uint64Kind, + protoreflect.Fixed32Kind, protoreflect.Fixed64Kind: + return x.Uint() == y.Uint() + case protoreflect.FloatKind, protoreflect.DoubleKind: + return x.Float() == y.Float() + case protoreflect.StringKind: + return x.String() == y.String() + case protoreflect.BytesKind: + return bytes.Equal(x.Bytes(), y.Bytes()) + case protoreflect.MessageKind, protoreflect.GroupKind: + return equalMessage(x.Message(), y.Message()) + default: + return x.Interface() == y.Interface() + } +} + +func equalUnknown(x, y protoreflect.RawFields) bool { + lenX := len(x) + lenY := len(y) + if lenX != lenY { + return false + } + if lenX == 0 { + return true + } + if bytes.Equal([]byte(x), []byte(y)) { + return true + } + + mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields) + my := make(map[protoreflect.FieldNumber]protoreflect.RawFields) + for len(x) > 0 { + fnum, _, n := protowire.ConsumeField(x) + mx[fnum] = append(mx[fnum], x[:n]...) + x = x[n:] + } + for len(y) > 0 { + fnum, _, n := protowire.ConsumeField(y) + my[fnum] = append(my[fnum], y[:n]...) + y = y[n:] + } + return reflect.DeepEqual(mx, my) +} + +func isAny(m protoreflect.Message) bool { + return string(m.Descriptor().FullName()) == "google.protobuf.Any" +} diff --git a/common/types/pb/equal_test.go b/common/types/pb/equal_test.go new file mode 100644 index 00000000..da1f591f --- /dev/null +++ b/common/types/pb/equal_test.go @@ -0,0 +1,398 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pb + +import ( + "math" + "testing" + + "google.golang.org/protobuf/proto" + + "google.golang.org/protobuf/types/known/anypb" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestEqual(t *testing.T) { + tests := []struct { + name string + a proto.Message + b proto.Message + out bool + }{ + { + name: "EqualEmptyInstances", + a: &proto3pb.TestAllTypes{}, + b: &proto3pb.TestAllTypes{}, + out: true, + }, + { + name: "NotEqualEmptyInstances", + a: &proto3pb.TestAllTypes{}, + b: &proto3pb.NestedTestAllTypes{}, + out: false, + }, + { + name: "EqualScalarFields", + a: &proto3pb.TestAllTypes{ + SingleBool: true, + SingleBytes: []byte("world"), + SingleDouble: 3.0, + SingleFloat: 1.5, + SingleInt32: 1, + SingleUint64: 1, + SingleString: "hello", + }, + b: &proto3pb.TestAllTypes{ + SingleBool: true, + SingleBytes: []byte("world"), + SingleDouble: 3.0, + SingleFloat: 1.5, + SingleInt32: 1, + SingleUint64: 1, + SingleString: "hello", + }, + out: true, + }, + { + name: "NotEqualFloatNan", + a: &proto3pb.TestAllTypes{ + SingleFloat: float32(math.NaN()), + }, + b: &proto3pb.TestAllTypes{ + SingleFloat: float32(math.NaN()), + }, + out: false, + }, + { + name: "NotEqualDifferentFieldsSet", + a: &proto3pb.TestAllTypes{ + SingleInt32: 1, + }, + b: &proto3pb.TestAllTypes{}, + out: false, + }, + { + name: "NotEqualDifferentFieldsSetReverse", + a: &proto3pb.TestAllTypes{}, + b: &proto3pb.TestAllTypes{ + SingleInt32: 1, + }, + out: false, + }, + { + name: "EqualListField", + a: &proto3pb.TestAllTypes{ + RepeatedInt32: []int32{1, 2, 3, 4}, + }, + b: &proto3pb.TestAllTypes{ + RepeatedInt32: []int32{1, 2, 3, 4}, + }, + out: true, + }, + { + name: "NotEqualListFieldDifferentLength", + a: &proto3pb.TestAllTypes{ + RepeatedInt32: []int32{1, 2, 3}, + }, + b: &proto3pb.TestAllTypes{ + RepeatedInt32: []int32{1, 2, 3, 4}, + }, + out: false, + }, + { + name: "NotEqualListFieldDifferentContent", + a: &proto3pb.TestAllTypes{ + RepeatedInt32: []int32{2, 1}, + }, + b: &proto3pb.TestAllTypes{ + RepeatedInt32: []int32{1, 2}, + }, + out: false, + }, + { + name: "EqualMapField", + a: &proto3pb.TestAllTypes{ + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: { + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + StandaloneEnum: proto3pb.TestAllTypes_BAR, + }, + }, + }, + 2: { + Payload: &proto3pb.TestAllTypes{}, + }, + }, + }, + b: &proto3pb.TestAllTypes{ + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: { + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + StandaloneEnum: proto3pb.TestAllTypes_BAR, + }, + }, + }, + 2: { + Payload: &proto3pb.TestAllTypes{}, + }, + }, + }, + out: true, + }, + { + name: "NotEqualMapFieldDifferentLength", + a: &proto3pb.TestAllTypes{ + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: { + Child: &proto3pb.NestedTestAllTypes{}, + }, + 2: { + Payload: &proto3pb.TestAllTypes{}, + }, + }, + }, + b: &proto3pb.TestAllTypes{ + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: { + Child: &proto3pb.NestedTestAllTypes{}, + }, + }, + }, + out: false, + }, + { + name: "EqualAnyBytes", + a: &proto3pb.TestAllTypes{ + SingleAny: packAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + SingleString: "three", + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: packAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + SingleString: "three", + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + out: true, + }, + { + name: "NotEqualDoublePackedAny", + a: &proto3pb.TestAllTypes{ + SingleAny: doublePackAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + SingleString: "three", + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: doublePackAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + SingleString: "three", + RepeatedInt32: []int32{1, 2, 3, 4}, + }), + }, + out: false, + }, + { + name: "NotEqualAnyTypeURL", + a: &proto3pb.TestAllTypes{ + SingleAny: packAny(t, &proto3pb.NestedTestAllTypes{}), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: packAny(t, &proto3pb.TestAllTypes{}), + }, + out: false, + }, + { + name: "NotEqualAnyFields", + a: &proto3pb.TestAllTypes{ + SingleAny: packAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: packAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + SingleString: "three", + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + out: false, + }, + { + name: "NotEqualAnyDeserializeA", + a: &proto3pb.TestAllTypes{ + SingleAny: badPackAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: badPackAny(t, &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleUint32: 2, + SingleString: "three", + RepeatedInt32: []int32{1, 2, 3}, + }), + }, + out: false, + }, + { + name: "EqualUnknownFields", + a: &proto3pb.TestAllTypes{ + SingleAny: misPackAny(t, &proto3pb.NestedTestAllTypes{ + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + SingleInt32: 1, + }, + }, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: misPackAny(t, &proto3pb.NestedTestAllTypes{ + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + SingleInt32: 1, + }, + }, + }), + }, + out: true, + }, + { + name: "NotEqualUnknownFieldsCount", + a: &proto3pb.TestAllTypes{ + SingleAny: misPackAny(t, &proto3pb.NestedTestAllTypes{ + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + SingleInt32: 1, + SingleFloat: 2.0, + }, + }, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: misPackAny(t, &proto3pb.NestedTestAllTypes{ + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + SingleInt32: 1, + }, + }, + }), + }, + out: false, + }, + { + name: "NotEqualUnknownFields", + a: &proto3pb.TestAllTypes{ + SingleAny: misPackAny(t, &proto3pb.NestedTestAllTypes{ + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + SingleInt64: 2, + }, + }, + }), + }, + b: &proto3pb.TestAllTypes{ + SingleAny: misPackAny(t, &proto3pb.NestedTestAllTypes{ + Child: &proto3pb.NestedTestAllTypes{ + Payload: &proto3pb.TestAllTypes{ + SingleInt32: 1, + }, + }, + }), + }, + out: false, + }, + { + name: "NotEqualOneNil", + a: nil, + b: &proto3pb.TestAllTypes{}, + out: false, + }, + { + name: "EqualBothNil", + a: nil, + b: nil, + out: true, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + got := Equal(tc.a, tc.b) + if got != tc.out { + t.Errorf("Equal(%v, %v) got %v, wanted %v", tc.a, tc.b, got, tc.out) + } + }) + } +} + +func packAny(t *testing.T, m proto.Message) *anypb.Any { + t.Helper() + any, err := anypb.New(m) + if err != nil { + t.Fatalf("anypb.New(%v) failed with error: %v", m, err) + } + return any +} + +func doublePackAny(t *testing.T, m proto.Message) *anypb.Any { + t.Helper() + any, err := anypb.New(m) + if err != nil { + t.Fatalf("anypb.New(%v) failed with error: %v", m, err) + } + any, err = anypb.New(any) + if err != nil { + t.Fatalf("anypb.New(%v) failed with error: %v", any, err) + } + return any +} + +func badPackAny(t *testing.T, m proto.Message) *anypb.Any { + t.Helper() + any, err := anypb.New(m) + if err != nil { + t.Fatalf("anypb.New(%v) failed with error: %v", m, err) + } + any.TypeUrl = "type.googleapis.com/BadType" + return any +} + +func misPackAny(t *testing.T, m proto.Message) *anypb.Any { + t.Helper() + any, err := anypb.New(m) + if err != nil { + t.Fatalf("anypb.New(%v) failed with error: %v", m, err) + } + any.TypeUrl = "type.googleapis.com/google.expr.proto3.test.TestAllTypes" + return any +} diff --git a/common/types/pb/file_test.go b/common/types/pb/file_test.go index 9be6971e..1d555556 100644 --- a/common/types/pb/file_test.go +++ b/common/types/pb/file_test.go @@ -1,3 +1,17 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package pb import ( diff --git a/common/types/pb/pb_test.go b/common/types/pb/pb_test.go index 549048f9..e4cba3a8 100644 --- a/common/types/pb/pb_test.go +++ b/common/types/pb/pb_test.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package pb reflects over protocol buffer descriptors to generate objects -// that simplify type, enum, and field lookup. package pb import (