Skip to content

Commit

Permalink
fixup! Add support for struct tag name overrides in native types
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickpichler committed May 21, 2024
1 parent 3253aa3 commit 9ffd7ea
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ var (
// ```go
// package identity
//
// type Account struct {
// ID int
// OwnerName string `cel:"owner"`
// }
// type Account struct {
// ID int
// OwnerName string `cel:"owner"`
// }
//
// ```
//
Expand Down Expand Up @@ -125,6 +125,7 @@ func NativeTypes(args ...any) cel.EnvOption {
}
}

// NativeTypesOption is a functional interface for configuring handling of native types.
type NativeTypesOption func(*nativeTypeOptions) error

type nativeTypeOptions struct {
Expand All @@ -133,6 +134,7 @@ type nativeTypeOptions struct {
parseStructTags bool
}

// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
func ParseStructTags(enabled bool) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.parseStructTags = true
Expand Down Expand Up @@ -597,18 +599,23 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
if !isValidObjectType(refType) {
return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType)
}
fieldNames := make(map[string]struct{})

for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(parseStructTags, field)
// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
if parseStructTags {
fieldNames := make(map[string]struct{})

for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(parseStructTags, field)

if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
} else {
fieldNames[fieldName] = struct{}{}
if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
} else {
fieldNames[fieldName] = struct{}{}
}
}
}

return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
Expand Down Expand Up @@ -670,6 +677,10 @@ func (t *nativeType) Value() any {
// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
if !t.parseStructTags {
return t.refType.FieldByName(fieldName)
}

for i := 0; i < t.refType.NumField(); i++ {
f := t.refType.Field(i)
if toFieldName(t.parseStructTags, f) == fieldName {
Expand Down

0 comments on commit 9ffd7ea

Please sign in to comment.