Skip to content

Commit

Permalink
Add support for struct tag name overrides in native types (#941)
Browse files Browse the repository at this point in the history
* Add support for struct tag name overrides in native types

In order to customize how to fields of native types can be accessed via
CEL, struct tag support has been added. It is now possible to provide
a struct field tag, that will act as the name for the corresponding
field

Here an example

```
type Person struct {
  Name string
  Age  int `cel:"age"`
}
```

and here is how to access the `Age` field from CEL:
```
person.age
```

* fixup! Add support for struct tag name overrides in native types

Make struct tag parsing opt-in

* fixup! Add support for struct tag name overrides in native types

Add docs

* fixup! Add support for struct tag name overrides in native types
  • Loading branch information
patrickpichler authored May 21, 2024
1 parent 3841093 commit f7b83ae
Show file tree
Hide file tree
Showing 2 changed files with 317 additions and 45 deletions.
166 changes: 144 additions & 22 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ext

import (
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -77,12 +78,45 @@ var (
// same advice holds if you are using custom type adapters and type providers. The native type
// provider composes over whichever type adapter and provider is configured in the cel.Env at
// the time that it is invoked.
func NativeTypes(refTypes ...any) cel.EnvOption {
//
// There is also the possibility to rename the fields of native structs by setting the `cel` tag
// for fields you want to override. In order to enable this feature, pass in the `EnableStructTag`
// option. Here is an example to see it in action:
//
// ```go
// package identity
//
// type Account struct {
// ID int
// OwnerName string `cel:"owner"`
// }
//
// ```
//
// The `OwnerName` field is now accessible in CEL via `owner`, e.g. `identity.Account{owner: 'bob'}`.
// In case there are duplicated field names in the struct, an error will be returned.
func NativeTypes(args ...any) cel.EnvOption {
return func(env *cel.Env) (*cel.Env, error) {
tp, err := newNativeTypeProvider(env.CELTypeAdapter(), env.CELTypeProvider(), refTypes...)
nativeTypes := make([]any, 0, len(args))
tpOptions := nativeTypeOptions{}

for _, v := range args {
switch v := v.(type) {
case NativeTypesOption:
err := v(&tpOptions)
if err != nil {
return nil, err
}
default:
nativeTypes = append(nativeTypes, v)
}
}

tp, err := newNativeTypeProvider(tpOptions, env.CELTypeAdapter(), env.CELTypeProvider(), nativeTypes...)
if err != nil {
return nil, err
}

env, err = cel.CustomTypeAdapter(tp)(env)
if err != nil {
return nil, err
Expand All @@ -91,20 +125,37 @@ func NativeTypes(refTypes ...any) cel.EnvOption {
}
}

func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) {
// NativeTypesOption is a functional interface for configuring handling of native types.
type NativeTypesOption func(*nativeTypeOptions) error

type nativeTypeOptions struct {
// parseStructTags controls if CEL should support struct field renames, by parsing
// struct field tags.
parseStructTags bool
}

// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
func ParseStructTags(enabled bool) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.parseStructTags = true
return nil
}
}

func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) {
nativeTypes := make(map[string]*nativeType, len(refTypes))
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
result, err := newNativeTypes(rt)
result, err := newNativeTypes(tpOptions.parseStructTags, rt)
if err != nil {
return nil, err
}
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
result, err := newNativeTypes(rt.Type())
result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type())
if err != nil {
return nil, err
}
Expand All @@ -119,13 +170,15 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy
nativeTypes: nativeTypes,
baseAdapter: adapter,
baseProvider: provider,
options: tpOptions,
}, nil
}

type nativeTypeProvider struct {
nativeTypes map[string]*nativeType
baseAdapter types.Adapter
baseProvider types.Provider
options nativeTypeOptions
}

// EnumValue proxies to the types.Provider configured at the times the NativeTypes
Expand Down Expand Up @@ -155,6 +208,18 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

func toFieldName(parseStructTag bool, f reflect.StructField) string {
if !parseStructTag {
return f.Name
}

if name, found := f.Tag.Lookup("cel"); found {
return name
}

return f.Name
}

// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
Expand All @@ -163,7 +228,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i))
}
return fields, true
}
Expand All @@ -173,6 +238,22 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
return tp.baseProvider.FindStructFieldNames(typeName)
}

// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
// searching for a matching field tag value or field name.
func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value {
if !parseStructTags {
return target.FieldByName(fieldName)
}

for i := 0; i < target.Type().NumField(); i++ {
f := target.Type().Field(i)
if toFieldName(parseStructTags, f) == fieldName {
return target.FieldByIndex(f.Index)
}
}
return reflect.Value{}
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
Expand All @@ -192,12 +273,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
return getFieldValue(tp, refField), nil
},
}, true
Expand Down Expand Up @@ -259,7 +340,7 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
time.Time:
return tp.baseAdapter.NativeToValue(val)
default:
return newNativeObject(tp, val, rawVal)
return tp.newNativeObject(val, rawVal)
}
default:
return tp.baseAdapter.NativeToValue(val)
Expand Down Expand Up @@ -319,13 +400,13 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
return nil, false
}

func newNativeObject(adapter types.Adapter, val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(refValue.Type())
func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(tp.options.parseStructTags, refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
return &nativeObj{
Adapter: adapter,
Adapter: tp,
val: val,
valType: valType,
refValue: refValue,
Expand Down Expand Up @@ -372,12 +453,13 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldName := toFieldName(o.valType.parseStructTags, fieldType)
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
return nil, err
}
fields[fieldType.Name] = fieldJSONVal.(*structpb.Value)
fields[fieldName] = fieldJSONVal.(*structpb.Value)
}
return &structpb.Struct{Fields: fields}, nil
}
Expand Down Expand Up @@ -469,8 +551,8 @@ func (o *nativeObj) Value() any {
return o.val
}

func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(rawType)
func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(parseStructTags, rawType)
if err != nil {
return nil, err
}
Expand All @@ -489,7 +571,7 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(t)
nt, ntErr := newNativeType(parseStructTags, t)
if ntErr != nil {
err = ntErr
return
Expand All @@ -505,23 +587,46 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
return result, err
}

func newNativeType(rawType reflect.Type) (*nativeType, error) {
var (
errDuplicatedFieldName = errors.New("field name already exists in struct")
)

func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
}
if !isValidObjectType(refType) {
return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType)
}

// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
if parseStructTags {
fieldNames := make(map[string]struct{})

for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(parseStructTags, field)

if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
} else {
fieldNames[fieldName] = struct{}{}
}
}
}

return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
parseStructTags: parseStructTags,
}, nil
}

type nativeType struct {
typeName string
refType reflect.Type
typeName string
refType reflect.Type
parseStructTags bool
}

// ConvertToNative implements ref.Val.ConvertToNative.
Expand Down Expand Up @@ -569,9 +674,26 @@ func (t *nativeType) Value() any {
return t.typeName
}

// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
if !t.parseStructTags {
return t.refType.FieldByName(fieldName)
}

for i := 0; i < t.refType.NumField(); i++ {
f := t.refType.Field(i)
if toFieldName(t.parseStructTags, f) == fieldName {
return f, true
}
}

return reflect.StructField{}, false
}

// hasField returns whether a field name has a corresponding Golang reflect.StructField
func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) {
f, found := t.refType.FieldByName(fieldName)
f, found := t.fieldByName(fieldName)
if !found || !f.IsExported() || !isSupportedType(f.Type) {
return reflect.StructField{}, false
}
Expand Down
Loading

0 comments on commit f7b83ae

Please sign in to comment.