diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index aed82744..5a26d74a 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -24,9 +24,11 @@ go_library( "//common/types:go_default_library", "//common/types/pb:go_default_library", "//common/types/ref:go_default_library", + "//common/types/traits:go_default_library", "//interpreter:go_default_library", "//interpreter/functions:go_default_library", "//parser:go_default_library", + "//test/proto3pb:go_default_library", "@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//reflect/protodesc:go_default_library", @@ -34,6 +36,9 @@ go_library( "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", "@org_golang_google_protobuf//types/descriptorpb:go_default_library", "@org_golang_google_protobuf//types/dynamicpb:go_default_library", + "@org_golang_google_protobuf//types/known/anypb:go_default_library", + "@org_golang_google_protobuf//types/known/durationpb:go_default_library", + "@org_golang_google_protobuf//types/known/timestamppb:go_default_library", ], ) @@ -41,6 +46,7 @@ go_test( name = "go_default_test", srcs = [ "cel_test.go", + "io_test.go", ], data = [ "//cel/testdata:gen_test_fds", @@ -60,5 +66,6 @@ go_test( "//test/proto3pb:go_default_library", "@io_bazel_rules_go//proto/wkt:descriptor_go_proto", "@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//types/known/structpb:go_default_library", ], ) diff --git a/cel/io.go b/cel/io.go index 4df33c06..5b5fc2df 100644 --- a/cel/io.go +++ b/cel/io.go @@ -15,12 +15,21 @@ package cel import ( + "errors" "fmt" + "time" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/parser" + "google.golang.org/protobuf/proto" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + anypb "google.golang.org/protobuf/types/known/anypb" + dpb "google.golang.org/protobuf/types/known/durationpb" + tpb "google.golang.org/protobuf/types/known/timestamppb" ) // CheckedExprToAst converts a checked expression proto message to an Ast. @@ -120,3 +129,178 @@ func AstToString(a *Ast) (string, error) { info := a.SourceInfo() return parser.Unparse(expr, info) } + +// RefValueToValue converts between ref.Val and api.expr.Value. +// The result Value is the serialized proto form. The ref.Val must not be error or unknown. +func RefValueToValue(res ref.Val) (*exprpb.Value, error) { + switch res.Type() { + case types.BoolType: + return &exprpb.Value{ + Kind: &exprpb.Value_BoolValue{BoolValue: res.Value().(bool)}}, nil + case types.BytesType: + return &exprpb.Value{ + Kind: &exprpb.Value_BytesValue{BytesValue: res.Value().([]byte)}}, nil + case types.DoubleType: + return &exprpb.Value{ + Kind: &exprpb.Value_DoubleValue{DoubleValue: res.Value().(float64)}}, nil + case types.IntType: + return &exprpb.Value{ + Kind: &exprpb.Value_Int64Value{Int64Value: res.Value().(int64)}}, nil + case types.ListType: + l := res.(traits.Lister) + sz := l.Size().(types.Int) + elts := make([]*exprpb.Value, 0, int64(sz)) + for i := types.Int(0); i < sz; i++ { + v, err := RefValueToValue(l.Get(i)) + if err != nil { + return nil, err + } + elts = append(elts, v) + } + return &exprpb.Value{ + Kind: &exprpb.Value_ListValue{ + ListValue: &exprpb.ListValue{Values: elts}}}, nil + case types.MapType: + mapper := res.(traits.Mapper) + sz := mapper.Size().(types.Int) + entries := make([]*exprpb.MapValue_Entry, 0, int64(sz)) + for it := mapper.Iterator(); it.HasNext().(types.Bool); { + k := it.Next() + v := mapper.Get(k) + kv, err := RefValueToValue(k) + if err != nil { + return nil, err + } + vv, err := RefValueToValue(v) + if err != nil { + return nil, err + } + entries = append(entries, &exprpb.MapValue_Entry{Key: kv, Value: vv}) + } + return &exprpb.Value{ + Kind: &exprpb.Value_MapValue{ + MapValue: &exprpb.MapValue{Entries: entries}}}, nil + case types.NullType: + return &exprpb.Value{ + Kind: &exprpb.Value_NullValue{}}, nil + case types.StringType: + return &exprpb.Value{ + Kind: &exprpb.Value_StringValue{StringValue: res.Value().(string)}}, nil + case types.TypeType: + typeName := res.(ref.Type).TypeName() + return &exprpb.Value{Kind: &exprpb.Value_TypeValue{TypeValue: typeName}}, nil + case types.UintType: + return &exprpb.Value{ + Kind: &exprpb.Value_Uint64Value{Uint64Value: res.Value().(uint64)}}, nil + case types.DurationType: + d, ok := res.Value().(time.Duration) + if !ok { + return nil, errors.New("Expected time.Duration") + } + any, err := anypb.New(dpb.New(d)) + if err != nil { + return nil, err + } + return &exprpb.Value{ + Kind: &exprpb.Value_ObjectValue{ObjectValue: any}}, nil + case types.TimestampType: + t, ok := res.Value().(time.Time) + if !ok { + return nil, errors.New("Expected time.Time") + } + any, err := anypb.New(tpb.New(t)) + if err != nil { + return nil, err + } + return &exprpb.Value{ + Kind: &exprpb.Value_ObjectValue{ObjectValue: any}}, nil + default: + // Object type + pb, ok := res.Value().(proto.Message) + if !ok { + return nil, errors.New("Expected proto message") + } + any, err := anypb.New(pb) + if err != nil { + return nil, err + } + return &exprpb.Value{ + Kind: &exprpb.Value_ObjectValue{ObjectValue: any}}, nil + } +} + +var ( + typeNameToTypeValue = map[string]*types.TypeValue{ + "bool": types.BoolType, + "bytes": types.BytesType, + "double": types.DoubleType, + "null_type": types.NullType, + "int": types.IntType, + "list": types.ListType, + "map": types.MapType, + "string": types.StringType, + "type": types.TypeType, + "uint": types.UintType, + } +) + +// ValueToRefValue converts between exprpb.Value and ref.Val. +func ValueToRefValue(adapter ref.TypeAdapter, v *exprpb.Value) (ref.Val, error) { + switch v.Kind.(type) { + case *exprpb.Value_NullValue: + return types.NullValue, nil + case *exprpb.Value_BoolValue: + return types.Bool(v.GetBoolValue()), nil + case *exprpb.Value_Int64Value: + return types.Int(v.GetInt64Value()), nil + case *exprpb.Value_Uint64Value: + return types.Uint(v.GetUint64Value()), nil + case *exprpb.Value_DoubleValue: + return types.Double(v.GetDoubleValue()), nil + case *exprpb.Value_StringValue: + return types.String(v.GetStringValue()), nil + case *exprpb.Value_BytesValue: + return types.Bytes(v.GetBytesValue()), nil + case *exprpb.Value_ObjectValue: + any := v.GetObjectValue() + msg, err := anypb.UnmarshalNew(any, proto.UnmarshalOptions{DiscardUnknown: true}) + if err != nil { + return nil, err + } + return adapter.NativeToValue(msg.(proto.Message)), nil + case *exprpb.Value_MapValue: + m := v.GetMapValue() + entries := make(map[ref.Val]ref.Val) + for _, entry := range m.Entries { + key, err := ValueToRefValue(adapter, entry.Key) + if err != nil { + return nil, err + } + pb, err := ValueToRefValue(adapter, entry.Value) + if err != nil { + return nil, err + } + entries[key] = pb + } + return adapter.NativeToValue(entries), nil + case *exprpb.Value_ListValue: + l := v.GetListValue() + elts := make([]ref.Val, len(l.Values)) + for i, e := range l.Values { + rv, err := ValueToRefValue(adapter, e) + if err != nil { + return nil, err + } + elts[i] = rv + } + return adapter.NativeToValue(elts), nil + case *exprpb.Value_TypeValue: + typeName := v.GetTypeValue() + tv, ok := typeNameToTypeValue[typeName] + if ok { + return tv, nil + } + return types.NewObjectTypeValue(typeName), nil + } + return nil, errors.New("unknown value") +} diff --git a/cel/io_test.go b/cel/io_test.go index 21e905e5..52b27e65 100644 --- a/cel/io_test.go +++ b/cel/io_test.go @@ -15,14 +15,57 @@ package cel import ( + "fmt" "testing" "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/test/proto3pb" "google.golang.org/protobuf/proto" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) +func TestRefValueToValueRoundTrip(t *testing.T) { + tests := []struct { + value interface{} + }{ + {value: types.NullValue}, + {value: types.Bool(true)}, + {value: types.String("abc")}, + {value: types.Double(0.0)}, + {value: types.Bytes(make([]byte, 0, 5))}, + {value: types.Int(0)}, + {value: types.Uint(0)}, + {value: map[int64]int64{1: 1}}, + {value: []interface{}{true, "abc"}}, + {value: &proto3pb.TestAllTypes{SingleString: "abc"}}, + } + + env, err := NewEnv(Types(&proto3pb.TestAllTypes{})) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("[%d]%v", i, tc.value), func(t *testing.T) { + refVal := env.TypeAdapter().NativeToValue(tc.value) + val, err := RefValueToValue(refVal) + if err != nil { + t.Fatalf("RefValueToValue(%v) failed with error: %v", refVal, err) + } + actual, err := ValueToRefValue(env.TypeAdapter(), val) + if err != nil { + t.Fatalf("ValueToRefValue() failed: &v", err) + } + if refVal.Equal(actual) != types.True { + t.Errorf("got val %v, wanted %v", actual, refVal) + } + }) + } +} + func TestAstToProto(t *testing.T) { stdEnv, _ := NewEnv(Declarations( decls.NewVar("a", decls.Dyn), diff --git a/common/types/BUILD.bazel b/common/types/BUILD.bazel index 63842886..1f77c86a 100644 --- a/common/types/BUILD.bazel +++ b/common/types/BUILD.bazel @@ -39,6 +39,9 @@ go_library( "//common/types/traits:go_default_library", "@com_github_stoewer_go_strcase//:go_default_library", "@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", + "@org_golang_google_genproto//googleapis/rpc/status:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", "@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", diff --git a/server/BUILD.bazel b/server/BUILD.bazel index 7174aa9a..4629ca43 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -16,7 +16,6 @@ go_library( "//common:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", - "//common/types/traits:go_default_library", "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_go_proto", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_go_proto", "@org_golang_google_genproto//googleapis/api/expr/conformance/v1alpha1:go_default_library", @@ -26,8 +25,6 @@ go_library( "@org_golang_google_grpc//status:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//types/known/anypb:go_default_library", - "@org_golang_google_protobuf//types/known/durationpb:go_default_library", - "@org_golang_google_protobuf//types/known/timestamppb:go_default_library", ], ) diff --git a/server/server.go b/server/server.go index 613765b2..05232ddb 100644 --- a/server/server.go +++ b/server/server.go @@ -18,25 +18,19 @@ package server import ( "context" "fmt" - "time" "github.com/google/cel-go/cel" "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - "github.com/google/cel-go/common/types/traits" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" test2pb "github.com/google/cel-spec/proto/test/v1/proto2/test_all_types" test3pb "github.com/google/cel-spec/proto/test/v1/proto3/test_all_types" confpb "google.golang.org/genproto/googleapis/api/expr/conformance/v1alpha1" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" rpcpb "google.golang.org/genproto/googleapis/rpc/status" - anypb "google.golang.org/protobuf/types/known/anypb" - dpb "google.golang.org/protobuf/types/known/durationpb" - tpb "google.golang.org/protobuf/types/known/timestamppb" ) // ConformanceServer contains the server state. @@ -190,7 +184,7 @@ func RefValueToExprValue(res ref.Val, err error) (*exprpb.ExprValue, error) { }, }}, nil } - v, err := RefValueToValue(res) + v, err := cel.RefValueToValue(res) if err != nil { return nil, err } @@ -198,125 +192,11 @@ func RefValueToExprValue(res ref.Val, err error) (*exprpb.ExprValue, error) { Kind: &exprpb.ExprValue_Value{Value: v}}, nil } -var ( - typeNameToTypeValue = map[string]*types.TypeValue{ - "bool": types.BoolType, - "bytes": types.BytesType, - "double": types.DoubleType, - "null_type": types.NullType, - "int": types.IntType, - "list": types.ListType, - "map": types.MapType, - "string": types.StringType, - "type": types.TypeType, - "uint": types.UintType, - } -) - -// RefValueToValue converts between ref.Val and Value. -// The ref.Val must not be error or unknown. -func RefValueToValue(res ref.Val) (*exprpb.Value, error) { - switch res.Type() { - case types.BoolType: - return &exprpb.Value{ - Kind: &exprpb.Value_BoolValue{BoolValue: res.Value().(bool)}}, nil - case types.BytesType: - return &exprpb.Value{ - Kind: &exprpb.Value_BytesValue{BytesValue: res.Value().([]byte)}}, nil - case types.DoubleType: - return &exprpb.Value{ - Kind: &exprpb.Value_DoubleValue{DoubleValue: res.Value().(float64)}}, nil - case types.IntType: - return &exprpb.Value{ - Kind: &exprpb.Value_Int64Value{Int64Value: res.Value().(int64)}}, nil - case types.ListType: - l := res.(traits.Lister) - sz := l.Size().(types.Int) - elts := make([]*exprpb.Value, 0, int64(sz)) - for i := types.Int(0); i < sz; i++ { - v, err := RefValueToValue(l.Get(i)) - if err != nil { - return nil, err - } - elts = append(elts, v) - } - return &exprpb.Value{ - Kind: &exprpb.Value_ListValue{ - ListValue: &exprpb.ListValue{Values: elts}}}, nil - case types.MapType: - mapper := res.(traits.Mapper) - sz := mapper.Size().(types.Int) - entries := make([]*exprpb.MapValue_Entry, 0, int64(sz)) - for it := mapper.Iterator(); it.HasNext().(types.Bool); { - k := it.Next() - v := mapper.Get(k) - kv, err := RefValueToValue(k) - if err != nil { - return nil, err - } - vv, err := RefValueToValue(v) - if err != nil { - return nil, err - } - entries = append(entries, &exprpb.MapValue_Entry{Key: kv, Value: vv}) - } - return &exprpb.Value{ - Kind: &exprpb.Value_MapValue{ - MapValue: &exprpb.MapValue{Entries: entries}}}, nil - case types.NullType: - return &exprpb.Value{ - Kind: &exprpb.Value_NullValue{}}, nil - case types.StringType: - return &exprpb.Value{ - Kind: &exprpb.Value_StringValue{StringValue: res.Value().(string)}}, nil - case types.TypeType: - typeName := res.(ref.Type).TypeName() - return &exprpb.Value{Kind: &exprpb.Value_TypeValue{TypeValue: typeName}}, nil - case types.UintType: - return &exprpb.Value{ - Kind: &exprpb.Value_Uint64Value{Uint64Value: res.Value().(uint64)}}, nil - case types.DurationType: - d, ok := res.Value().(time.Duration) - if !ok { - return nil, status.New(codes.InvalidArgument, "Expected time.Duration").Err() - } - any, err := anypb.New(dpb.New(d)) - if err != nil { - return nil, err - } - return &exprpb.Value{ - Kind: &exprpb.Value_ObjectValue{ObjectValue: any}}, nil - case types.TimestampType: - t, ok := res.Value().(time.Time) - if !ok { - return nil, status.New(codes.InvalidArgument, "Expected time.Time").Err() - } - any, err := anypb.New(tpb.New(t)) - if err != nil { - return nil, err - } - return &exprpb.Value{ - Kind: &exprpb.Value_ObjectValue{ObjectValue: any}}, nil - default: - // Object type - pb, ok := res.Value().(proto.Message) - if !ok { - return nil, status.New(codes.InvalidArgument, "Expected proto message").Err() - } - any, err := anypb.New(pb) - if err != nil { - return nil, err - } - return &exprpb.Value{ - Kind: &exprpb.Value_ObjectValue{ObjectValue: any}}, nil - } -} - // ExprValueToRefValue converts between exprpb.ExprValue and ref.Val. func ExprValueToRefValue(adapter ref.TypeAdapter, ev *exprpb.ExprValue) (ref.Val, error) { switch ev.Kind.(type) { case *exprpb.ExprValue_Value: - return ValueToRefValue(adapter, ev.GetValue()) + return cel.ValueToRefValue(adapter, ev.GetValue()) case *exprpb.ExprValue_Error: // An error ExprValue is a repeated set of rpcpb.Status // messages, with no convention for the status details. @@ -332,63 +212,3 @@ func ExprValueToRefValue(adapter ref.TypeAdapter, ev *exprpb.ExprValue) (ref.Val return nil, status.New(codes.InvalidArgument, "unknown ExprValue kind").Err() } -// ValueToRefValue converts between exprpb.Value and ref.Val. -func ValueToRefValue(adapter ref.TypeAdapter, v *exprpb.Value) (ref.Val, error) { - switch v.Kind.(type) { - case *exprpb.Value_NullValue: - return types.NullValue, nil - case *exprpb.Value_BoolValue: - return types.Bool(v.GetBoolValue()), nil - case *exprpb.Value_Int64Value: - return types.Int(v.GetInt64Value()), nil - case *exprpb.Value_Uint64Value: - return types.Uint(v.GetUint64Value()), nil - case *exprpb.Value_DoubleValue: - return types.Double(v.GetDoubleValue()), nil - case *exprpb.Value_StringValue: - return types.String(v.GetStringValue()), nil - case *exprpb.Value_BytesValue: - return types.Bytes(v.GetBytesValue()), nil - case *exprpb.Value_ObjectValue: - any := v.GetObjectValue() - msg, err := anypb.UnmarshalNew(any, proto.UnmarshalOptions{DiscardUnknown: true}) - if err != nil { - return nil, err - } - return adapter.NativeToValue(msg.(proto.Message)), nil - case *exprpb.Value_MapValue: - m := v.GetMapValue() - entries := make(map[ref.Val]ref.Val) - for _, entry := range m.Entries { - key, err := ValueToRefValue(adapter, entry.Key) - if err != nil { - return nil, err - } - pb, err := ValueToRefValue(adapter, entry.Value) - if err != nil { - return nil, err - } - entries[key] = pb - } - return adapter.NativeToValue(entries), nil - case *exprpb.Value_ListValue: - l := v.GetListValue() - elts := make([]ref.Val, len(l.Values)) - for i, e := range l.Values { - rv, err := ValueToRefValue(adapter, e) - if err != nil { - return nil, err - } - elts[i] = rv - } - return adapter.NativeToValue(elts), nil - case *exprpb.Value_TypeValue: - typeName := v.GetTypeValue() - tv, ok := typeNameToTypeValue[typeName] - if ok { - return tv, nil - } - return types.NewObjectTypeValue(typeName), nil - } - return nil, status.New(codes.InvalidArgument, "unknown value").Err() -}