diff --git a/ext/native.go b/ext/native.go index d1b78777..0e4bd305 100644 --- a/ext/native.go +++ b/ext/native.go @@ -15,6 +15,7 @@ package ext import ( + "errors" "fmt" "reflect" "strings" @@ -77,12 +78,45 @@ var ( // same advice holds if you are using custom type adapters and type providers. The native type // provider composes over whichever type adapter and provider is configured in the cel.Env at // the time that it is invoked. -func NativeTypes(refTypes ...any) cel.EnvOption { +// +// There is also the possibility to rename the fields of native structs by setting the `cel` tag +// for fields you want to override. In order to enable this feature, pass in the `EnableStructTag` +// option. Here is an example to see it in action: +// +// ```go +// package identity +// +// type Account struct { +// ID int +// OwnerName string `cel:"owner"` +// } +// +// ``` +// +// The `OwnerName` field is now accessible in CEL via `owner`, e.g. `identity.Account{owner: 'bob'}`. +// In case there are duplicated field names in the struct, an error will be returned. +func NativeTypes(args ...any) cel.EnvOption { return func(env *cel.Env) (*cel.Env, error) { - tp, err := newNativeTypeProvider(env.CELTypeAdapter(), env.CELTypeProvider(), refTypes...) + nativeTypes := make([]any, 0, len(args)) + tpOptions := nativeTypeOptions{} + + for _, v := range args { + switch v := v.(type) { + case NativeTypesOption: + err := v(&tpOptions) + if err != nil { + return nil, err + } + default: + nativeTypes = append(nativeTypes, v) + } + } + + tp, err := newNativeTypeProvider(tpOptions, env.CELTypeAdapter(), env.CELTypeProvider(), nativeTypes...) if err != nil { return nil, err } + env, err = cel.CustomTypeAdapter(tp)(env) if err != nil { return nil, err @@ -91,12 +125,29 @@ func NativeTypes(refTypes ...any) cel.EnvOption { } } -func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) { +// NativeTypesOption is a functional interface for configuring handling of native types. +type NativeTypesOption func(*nativeTypeOptions) error + +type nativeTypeOptions struct { + // parseStructTags controls if CEL should support struct field renames, by parsing + // struct field tags. + parseStructTags bool +} + +// ParseStructTags configures if native types field names should be overridable by CEL struct tags. +func ParseStructTags(enabled bool) NativeTypesOption { + return func(ntp *nativeTypeOptions) error { + ntp.parseStructTags = true + return nil + } +} + +func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) { nativeTypes := make(map[string]*nativeType, len(refTypes)) for _, refType := range refTypes { switch rt := refType.(type) { case reflect.Type: - result, err := newNativeTypes(rt) + result, err := newNativeTypes(tpOptions.parseStructTags, rt) if err != nil { return nil, err } @@ -104,7 +155,7 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy nativeTypes[result[idx].TypeName()] = result[idx] } case reflect.Value: - result, err := newNativeTypes(rt.Type()) + result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type()) if err != nil { return nil, err } @@ -119,6 +170,7 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy nativeTypes: nativeTypes, baseAdapter: adapter, baseProvider: provider, + options: tpOptions, }, nil } @@ -126,6 +178,7 @@ type nativeTypeProvider struct { nativeTypes map[string]*nativeType baseAdapter types.Adapter baseProvider types.Provider + options nativeTypeOptions } // EnumValue proxies to the types.Provider configured at the times the NativeTypes @@ -155,6 +208,18 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool return tp.baseProvider.FindStructType(typeName) } +func toFieldName(parseStructTag bool, f reflect.StructField) string { + if !parseStructTag { + return f.Name + } + + if name, found := f.Tag.Lookup("cel"); found { + return name + } + + return f.Name +} + // FindStructFieldNames looks up the type definition first from the native types, then from // the backing provider type set. If found, a set of field names corresponding to the type // will be returned. @@ -163,7 +228,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b fieldCount := t.refType.NumField() fields := make([]string, fieldCount) for i := 0; i < fieldCount; i++ { - fields[i] = t.refType.Field(i).Name + fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i)) } return fields, true } @@ -173,6 +238,22 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b return tp.baseProvider.FindStructFieldNames(typeName) } +// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by +// searching for a matching field tag value or field name. +func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value { + if !parseStructTags { + return target.FieldByName(fieldName) + } + + for i := 0; i < target.Type().NumField(); i++ { + f := target.Type().Field(i) + if toFieldName(parseStructTags, f) == fieldName { + return target.FieldByIndex(f.Index) + } + } + return reflect.Value{} +} + // FindStructFieldType looks up a native type's field definition, and if the type name is not a native // type then proxies to the composed types.Provider func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) { @@ -192,12 +273,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (* Type: celType, IsSet: func(obj any) bool { refVal := reflect.Indirect(reflect.ValueOf(obj)) - refField := refVal.FieldByName(fieldName) + refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName) return !refField.IsZero() }, GetFrom: func(obj any) (any, error) { refVal := reflect.Indirect(reflect.ValueOf(obj)) - refField := refVal.FieldByName(fieldName) + refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName) return getFieldValue(tp, refField), nil }, }, true @@ -259,7 +340,7 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val { time.Time: return tp.baseAdapter.NativeToValue(val) default: - return newNativeObject(tp, val, rawVal) + return tp.newNativeObject(val, rawVal) } default: return tp.baseAdapter.NativeToValue(val) @@ -319,13 +400,13 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) { return nil, false } -func newNativeObject(adapter types.Adapter, val any, refValue reflect.Value) ref.Val { - valType, err := newNativeType(refValue.Type()) +func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val { + valType, err := newNativeType(tp.options.parseStructTags, refValue.Type()) if err != nil { return types.NewErr(err.Error()) } return &nativeObj{ - Adapter: adapter, + Adapter: tp, val: val, valType: valType, refValue: refValue, @@ -372,12 +453,13 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) { if !fieldValue.IsValid() || fieldValue.IsZero() { continue } + fieldName := toFieldName(o.valType.parseStructTags, fieldType) fieldCELVal := o.NativeToValue(fieldValue.Interface()) fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType) if err != nil { return nil, err } - fields[fieldType.Name] = fieldJSONVal.(*structpb.Value) + fields[fieldName] = fieldJSONVal.(*structpb.Value) } return &structpb.Struct{Fields: fields}, nil } @@ -469,8 +551,8 @@ func (o *nativeObj) Value() any { return o.val } -func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) { - nt, err := newNativeType(rawType) +func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) { + nt, err := newNativeType(parseStructTags, rawType) if err != nil { return nil, err } @@ -489,7 +571,7 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) { return } alreadySeen[t.String()] = struct{}{} - nt, ntErr := newNativeType(t) + nt, ntErr := newNativeType(parseStructTags, t) if ntErr != nil { err = ntErr return @@ -505,7 +587,11 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) { return result, err } -func newNativeType(rawType reflect.Type) (*nativeType, error) { +var ( + errDuplicatedFieldName = errors.New("field name already exists in struct") +) + +func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) { refType := rawType if refType.Kind() == reflect.Pointer { refType = refType.Elem() @@ -513,15 +599,34 @@ func newNativeType(rawType reflect.Type) (*nativeType, error) { if !isValidObjectType(refType) { return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType) } + + // Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled. + if parseStructTags { + fieldNames := make(map[string]struct{}) + + for idx := 0; idx < refType.NumField(); idx++ { + field := refType.Field(idx) + fieldName := toFieldName(parseStructTags, field) + + if _, found := fieldNames[fieldName]; found { + return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName) + } else { + fieldNames[fieldName] = struct{}{} + } + } + } + return &nativeType{ - typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), - refType: refType, + typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), + refType: refType, + parseStructTags: parseStructTags, }, nil } type nativeType struct { - typeName string - refType reflect.Type + typeName string + refType reflect.Type + parseStructTags bool } // ConvertToNative implements ref.Val.ConvertToNative. @@ -569,9 +674,26 @@ func (t *nativeType) Value() any { return t.typeName } +// fieldByName returns the corresponding reflect.StructField for the give name either by matching +// field tag or field name. +func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) { + if !t.parseStructTags { + return t.refType.FieldByName(fieldName) + } + + for i := 0; i < t.refType.NumField(); i++ { + f := t.refType.Field(i) + if toFieldName(t.parseStructTags, f) == fieldName { + return f, true + } + } + + return reflect.StructField{}, false +} + // hasField returns whether a field name has a corresponding Golang reflect.StructField func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) { - f, found := t.refType.FieldByName(fieldName) + f, found := t.fieldByName(fieldName) if !found || !f.IsExported() || !isSupportedType(f.Type) { return reflect.StructField{}, false } diff --git a/ext/native_test.go b/ext/native_test.go index 27479cfc..a44bdc04 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -15,6 +15,7 @@ package ext import ( + "errors" "fmt" "reflect" "sort" @@ -38,9 +39,10 @@ import ( func TestNativeTypes(t *testing.T) { var nativeTests = []struct { - expr string - out any - in any + expr string + out any + in any + envOpts []any }{ { expr: `ext.TestAllTypes{ @@ -60,17 +62,20 @@ func TestNativeTypes(t *testing.T) { ext.TestNestedType{ NestedListVal:['goodbye', 'cruel', 'world'], NestedMapVal: {42: true}, + custom_name: 'name', }, ], ArrayVal: [ ext.TestNestedType{ NestedListVal:['goodbye', 'cruel', 'world'], NestedMapVal: {42: true}, + custom_name: 'name', }, ], MapVal: {'map-key': ext.TestAllTypes{BoolVal: true}}, CustomSliceVal: [ext.TestNestedSliceType{Value: 'none'}], CustomMapVal: {'even': ext.TestMapVal{Value: 'more'}}, + custom_name: 'name', }`, out: &TestAllTypes{ NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}}, @@ -87,17 +92,85 @@ func TestNativeTypes(t *testing.T) { Uint64Val: uint64(200), ListVal: []*TestNestedType{ { - NestedListVal: []string{"goodbye", "cruel", "world"}, - NestedMapVal: map[int64]bool{42: true}, + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + NestedCustomName: "name", }, }, ArrayVal: [1]*TestNestedType{{ - NestedListVal: []string{"goodbye", "cruel", "world"}, - NestedMapVal: map[int64]bool{42: true}, + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + NestedCustomName: "name", }}, MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}}, CustomSliceVal: []TestNestedSliceType{{Value: "none"}}, CustomMapVal: map[string]TestMapVal{"even": {Value: "more"}}, + CustomName: "name", + }, + envOpts: []any{ParseStructTags(true)}, + }, + { + expr: `ext.TestAllTypes{ + NestedVal: ext.TestNestedType{NestedMapVal: {1: false}}, + BoolVal: true, + BytesVal: b'hello', + DurationVal: duration('5s'), + DoubleVal: 1.5, + FloatVal: 2.5, + Int32Val: 10, + Int64Val: 20, + StringVal: 'hello world', + TimestampVal: timestamp('2011-08-06T01:23:45Z'), + Uint32Val: 100u, + Uint64Val: 200u, + ListVal: [ + ext.TestNestedType{ + NestedListVal:['goodbye', 'cruel', 'world'], + NestedMapVal: {42: true}, + NestedCustomName: 'name', + }, + ], + ArrayVal: [ + ext.TestNestedType{ + NestedListVal:['goodbye', 'cruel', 'world'], + NestedMapVal: {42: true}, + NestedCustomName: 'name', + }, + ], + MapVal: {'map-key': ext.TestAllTypes{BoolVal: true}}, + CustomSliceVal: [ext.TestNestedSliceType{Value: 'none'}], + CustomMapVal: {'even': ext.TestMapVal{Value: 'more'}}, + CustomName: 'name', + }`, + out: &TestAllTypes{ + NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}}, + BoolVal: true, + BytesVal: []byte("hello"), + DurationVal: time.Second * 5, + DoubleVal: 1.5, + FloatVal: 2.5, + Int32Val: 10, + Int64Val: 20, + StringVal: "hello world", + TimestampVal: mustParseTime(t, "2011-08-06T01:23:45Z"), + Uint32Val: uint32(100), + Uint64Val: uint64(200), + ListVal: []*TestNestedType{ + { + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + NestedCustomName: "name", + }, + }, + ArrayVal: [1]*TestNestedType{{ + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + NestedCustomName: "name", + }}, + MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}}, + CustomSliceVal: []TestNestedSliceType{{Value: "none"}}, + CustomMapVal: map[string]TestMapVal{"even": {Value: "more"}}, + CustomName: "name", }, }, { @@ -126,6 +199,8 @@ func TestNativeTypes(t *testing.T) { {expr: `ext.TestAllTypes{}.TimestampVal == timestamp(0)`}, {expr: `test.TestAllTypes{}.single_timestamp == timestamp(0)`}, {expr: `[TestAllTypes{BoolVal: true}, TestAllTypes{BoolVal: false}].exists(t, t.BoolVal == true)`}, + {expr: `[TestAllTypes{CustomName: 'Alice'}, TestAllTypes{CustomName: 'Bob'}].exists(t, t.CustomName == 'Alice')`}, + {expr: `[TestAllTypes{custom_name: 'Alice'}, TestAllTypes{custom_name: 'Bob'}].exists(t, t.custom_name == 'Alice')`, envOpts: []any{ParseStructTags(true)}}, { expr: `tests.all(t, t.Int32Val > 17)`, in: map[string]any{ @@ -133,10 +208,10 @@ func TestNativeTypes(t *testing.T) { }, }, } - env := testNativeEnv(t) for i, tst := range nativeTests { tc := tst t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + env := testNativeEnv(t, tc.envOpts...) var asts []*cel.Ast pAst, iss := env.Parse(tc.expr) if iss.Err() != nil { @@ -178,7 +253,7 @@ func TestNativeTypes(t *testing.T) { } func TestNativeFindStructFieldNames(t *testing.T) { - env := testNativeEnv(t) + env := testNativeEnv(t, ParseStructTags(true)) provider := env.CELTypeProvider() tests := []struct { typeName string @@ -186,7 +261,7 @@ func TestNativeFindStructFieldNames(t *testing.T) { }{ { typeName: "ext.TestNestedType", - fields: []string{"NestedListVal", "NestedMapVal"}, + fields: []string{"NestedListVal", "NestedMapVal", "custom_name"}, }, { typeName: "google.expr.proto3.test.TestAllTypes.NestedMessage", @@ -264,8 +339,9 @@ func TestNativeTypesStaticErrors(t *testing.T) { func TestNativeTypesJsonSerialization(t *testing.T) { tests := []struct { - expr string - out string + expr string + out string + additionalEnvOptions []any }{ { expr: `[b'string']`, @@ -287,10 +363,12 @@ func TestNativeTypesJsonSerialization(t *testing.T) { NestedVal: TestNestedType{ NestedListVal: ["first", "second"], }, - StringVal: "string" + StringVal: "string", + CustomName: "name", }`, out: `{ "BoolVal": true, + "CustomName": "name", "DoubleVal": 1.5, "DurationVal": "5s", "FloatVal": 2, @@ -310,11 +388,53 @@ func TestNativeTypesJsonSerialization(t *testing.T) { "StringVal": "string" }`, }, + { + expr: `TestAllTypes{ + BoolVal: true, + DurationVal: duration('5s'), + DoubleVal: 1.5, + FloatVal: 2.0, + Int32Val: 23, + Int64Val: 64, + MapVal: { + 'map-key': ext.TestAllTypes{ + BoolVal: true + } + }, + NestedVal: TestNestedType{ + NestedListVal: ["first", "second"], + }, + StringVal: "string", + custom_name: "name", + }`, + out: `{ + "BoolVal": true, + "DoubleVal": 1.5, + "DurationVal": "5s", + "FloatVal": 2, + "Int32Val": 23, + "Int64Val": 64, + "MapVal": { + "map-key": { + "BoolVal": true + } + }, + "NestedVal": { + "NestedListVal": [ + "first", + "second" + ] + }, + "StringVal": "string", + "custom_name": "name" + }`, + additionalEnvOptions: []any{ParseStructTags(true)}, + }, } - env := testNativeEnv(t) for i, tst := range tests { tc := tst t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + env := testNativeEnv(t, tst.additionalEnvOptions...) ast, iss := env.Compile(tc.expr) if iss.Err() != nil { t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err()) @@ -604,7 +724,7 @@ func TestNativeTypesWithOptional(t *testing.T) { } func TestNativeTypeConvertToType(t *testing.T) { - nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -617,7 +737,7 @@ func TestNativeTypeConvertToType(t *testing.T) { } func TestNativeTypeConvertToNative(t *testing.T) { - nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -628,7 +748,7 @@ func TestNativeTypeConvertToNative(t *testing.T) { } func TestNativeTypeHasTrait(t *testing.T) { - nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -638,7 +758,7 @@ func TestNativeTypeHasTrait(t *testing.T) { } func TestNativeTypeValue(t *testing.T) { - nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -647,8 +767,18 @@ func TestNativeTypeValue(t *testing.T) { } } +func TestNativeStructWithMultileSameFieldNames(t *testing.T) { + _, err := newNativeType(true, reflect.TypeOf(TestStructWithMultipleSameNames{})) + if err == nil { + t.Fatal("newNativeType() did not fail as expected") + } + if !errors.Is(err, errDuplicatedFieldName) { + t.Fatalf("newNativeType() exepected duplicated field name error, but got: %v", err) + } +} + // testEnv initializes the test environment common to all tests. -func testNativeEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { +func testNativeEnv(t *testing.T, opts ...any) *cel.Env { t.Helper() envOpts := []cel.EnvOption{ cel.Container("ext"), @@ -656,10 +786,23 @@ func testNativeEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { cel.Types(&proto3pb.TestAllTypes{}), cel.Variable("tests", cel.ListType(cel.ObjectType("ext.TestAllTypes"))), } - envOpts = append(envOpts, opts...) + nativeOpts := []any{ + reflect.ValueOf(&TestAllTypes{}), + } + for _, o := range opts { + switch opt := o.(type) { + case NativeTypesOption: + nativeOpts = append(nativeOpts, opt) + case cel.EnvOption: + envOpts = append(envOpts, opt) + default: + t.Fatalf("invalid option type: %s", reflect.TypeOf(o).Name()) + } + } + envOpts = append(envOpts, NativeTypes( - reflect.ValueOf(&TestAllTypes{}), + nativeOpts..., ), ) env, err := cel.NewEnv(envOpts...) @@ -678,9 +821,15 @@ func mustParseTime(t *testing.T, timestamp string) time.Time { return out } +type TestStructWithMultipleSameNames struct { + Name string + custom_name string `cel:"Name"` +} + type TestNestedType struct { - NestedListVal []string - NestedMapVal map[int64]bool + NestedListVal []string + NestedMapVal map[int64]bool + NestedCustomName string `cel:"custom_name"` } type TestAllTypes struct { @@ -703,6 +852,7 @@ type TestAllTypes struct { PbVal *proto3pb.TestAllTypes CustomSliceVal []TestNestedSliceType CustomMapVal map[string]TestMapVal + CustomName string `cel:"custom_name"` // channel types are not supported UnsupportedVal chan string