Skip to content

Commit

Permalink
FindStructTypeFields support for types.Provider (#814)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Aug 18, 2023
1 parent 26aa367 commit eaebecb
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 0 deletions.
7 changes: 7 additions & 0 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b
return nil, false
}

// FindStructFieldNames returns an empty set of field for the interop provider.
//
// To inspect the field names, migrate to a `types.Provider` implementation.
func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}

// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//
Expand Down
4 changes: 4 additions & 0 deletions cel/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ func (p *customCELProvider) FindStructType(typeName string) (*types.Type, bool)
return p.provider.FindStructType(typeName)
}

func (p *customCELProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return p.provider.FindStructFieldNames(typeName)
}

func (p *customCELProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
return p.provider.FindStructFieldType(structType, fieldName)
}
Expand Down
21 changes: 21 additions & 0 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type Provider interface {
// Returns false if not found.
FindStructType(structType string) (*Type, bool)

// FindStructFieldNames returns thet field names associated with the type, if the type
// is found.
FindStructFieldNames(structType string) ([]string, bool)

// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindStructFieldType(structType, fieldName string) (*FieldType, bool)
Expand Down Expand Up @@ -173,6 +177,23 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType,
GetFrom: field.GetFrom}, true
}

// FindStructFieldNames returns the set of field names for the given struct type,
// if the type exists in the registry.
func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return []string{}, false
}
fieldMap := msgType.FieldMap()
fields := make([]string, len(fieldMap))
idx := 0
for f := range fieldMap {
fields[idx] = f
idx++
}
return fields, true
}

// FindStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) {
Expand Down
34 changes: 34 additions & 0 deletions common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -132,6 +133,39 @@ func TestRegistryFindStructType(t *testing.T) {
}
}

func TestRegistryFindStructFieldNames(t *testing.T) {
reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{})
tests := []struct {
typeName string
fields []string
}{
{
typeName: "google.api.expr.v1alpha1.Reference",
fields: []string{"name", "overload_id", "value"},
},
{
typeName: "google.api.expr.v1alpha1.Decl",
fields: []string{"name", "ident", "function"},
},
{
typeName: "invalid.TypeName",
fields: []string{},
},
}

for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) {
fields, _ := reg.FindStructFieldNames(tc.typeName)
sort.Strings(fields)
sort.Strings(tc.fields)
if !reflect.DeepEqual(fields, tc.fields) {
t.Errorf("got %v, wanted %v", fields, tc.fields)
}
})
}
}

func TestRegistryFindStructFieldType(t *testing.T) {
reg := newTestRegistry(t)
err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile())
Expand Down
18 changes: 18 additions & 0 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
if t, found := tp.nativeTypes[typeName]; found {
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
}
return fields, true
}
if celTypeFields, found := tp.baseProvider.FindStructFieldNames(typeName); found {
return celTypeFields, true
}
return tp.baseProvider.FindStructFieldNames(typeName)
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
Expand Down
35 changes: 35 additions & 0 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package ext
import (
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -162,6 +163,40 @@ func TestNativeTypes(t *testing.T) {
}
}

func TestNativeFindStructFieldNames(t *testing.T) {
env := testNativeEnv(t)
provider := env.CELTypeProvider()
tests := []struct {
typeName string
fields []string
}{
{
typeName: "ext.TestNestedType",
fields: []string{"NestedListVal", "NestedMapVal"},
},
{
typeName: "google.expr.proto3.test.TestAllTypes.NestedMessage",
fields: []string{"bb"},
},
{
typeName: "invalid.TypeName",
fields: []string{},
},
}

for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) {
fields, _ := provider.FindStructFieldNames(tc.typeName)
sort.Strings(fields)
sort.Strings(tc.fields)
if !reflect.DeepEqual(fields, tc.fields) {
t.Errorf("got %v, wanted %v", fields, tc.fields)
}
})
}
}

func TestNativeTypesStaticErrors(t *testing.T) {
var nativeTests = []struct {
expr string
Expand Down

0 comments on commit eaebecb

Please sign in to comment.