Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FindStructTypeFields support for types.Provider #814

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b
return nil, false
}

// FindStructFieldNames returns an empty set of field for the interop provider.
//
// To inspect the field names, migrate to a `types.Provider` implementation.
func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}

// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//
Expand Down
4 changes: 4 additions & 0 deletions cel/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ func (p *customCELProvider) FindStructType(typeName string) (*types.Type, bool)
return p.provider.FindStructType(typeName)
}

func (p *customCELProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return p.provider.FindStructFieldNames(typeName)
}

func (p *customCELProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
return p.provider.FindStructFieldType(structType, fieldName)
}
Expand Down
21 changes: 21 additions & 0 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type Provider interface {
// Returns false if not found.
FindStructType(structType string) (*Type, bool)

// FindStructFieldNames returns thet field names associated with the type, if the type
// is found.
FindStructFieldNames(structType string) ([]string, bool)

// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindStructFieldType(structType, fieldName string) (*FieldType, bool)
Expand Down Expand Up @@ -173,6 +177,23 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType,
GetFrom: field.GetFrom}, true
}

// FindStructFieldNames returns the set of field names for the given struct type,
// if the type exists in the registry.
func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return []string{}, false
}
fieldMap := msgType.FieldMap()
fields := make([]string, len(fieldMap))
idx := 0
for f := range fieldMap {
fields[idx] = f
idx++
}
return fields, true
}

// FindStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) {
Expand Down
34 changes: 34 additions & 0 deletions common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -132,6 +133,39 @@ func TestRegistryFindStructType(t *testing.T) {
}
}

func TestRegistryFindStructFieldNames(t *testing.T) {
reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{})
tests := []struct {
typeName string
fields []string
}{
{
typeName: "google.api.expr.v1alpha1.Reference",
fields: []string{"name", "overload_id", "value"},
},
{
typeName: "google.api.expr.v1alpha1.Decl",
fields: []string{"name", "ident", "function"},
},
{
typeName: "invalid.TypeName",
fields: []string{},
},
}

for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) {
fields, _ := reg.FindStructFieldNames(tc.typeName)
sort.Strings(fields)
sort.Strings(tc.fields)
if !reflect.DeepEqual(fields, tc.fields) {
t.Errorf("got %v, wanted %v", fields, tc.fields)
}
})
}
}

func TestRegistryFindStructFieldType(t *testing.T) {
reg := newTestRegistry(t)
err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile())
Expand Down
18 changes: 18 additions & 0 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
if t, found := tp.nativeTypes[typeName]; found {
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
}
return fields, true
}
if celTypeFields, found := tp.baseProvider.FindStructFieldNames(typeName); found {
return celTypeFields, true
}
return tp.baseProvider.FindStructFieldNames(typeName)
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
Expand Down
35 changes: 35 additions & 0 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package ext
import (
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -162,6 +163,40 @@ func TestNativeTypes(t *testing.T) {
}
}

func TestNativeFindStructFieldNames(t *testing.T) {
env := testNativeEnv(t)
provider := env.CELTypeProvider()
tests := []struct {
typeName string
fields []string
}{
{
typeName: "ext.TestNestedType",
fields: []string{"NestedListVal", "NestedMapVal"},
},
{
typeName: "google.expr.proto3.test.TestAllTypes.NestedMessage",
fields: []string{"bb"},
},
{
typeName: "invalid.TypeName",
fields: []string{},
},
}

for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) {
fields, _ := provider.FindStructFieldNames(tc.typeName)
sort.Strings(fields)
sort.Strings(tc.fields)
if !reflect.DeepEqual(fields, tc.fields) {
t.Errorf("got %v, wanted %v", fields, tc.fields)
}
})
}
}

func TestNativeTypesStaticErrors(t *testing.T) {
var nativeTests = []struct {
expr string
Expand Down