From 8a1694bcb5ed548c0556259d8f38ed622e75a749 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 18 Aug 2023 14:43:07 -0700 Subject: [PATCH] FindStructTypeFields support for types.Provider --- cel/env.go | 7 +++++++ cel/env_test.go | 4 ++++ common/types/provider.go | 21 +++++++++++++++++++++ common/types/provider_test.go | 34 ++++++++++++++++++++++++++++++++++ ext/native.go | 18 ++++++++++++++++++ ext/native_test.go | 35 +++++++++++++++++++++++++++++++++++ 6 files changed, 119 insertions(+) diff --git a/cel/env.go b/cel/env.go index 96b61ad7..473604bc 100644 --- a/cel/env.go +++ b/cel/env.go @@ -773,6 +773,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b return nil, false } +// FindStructFieldNames returns an empty set of field for the interop provider. +// +// To inspect the field names, migrate to a `types.Provider` implementation. +func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) { + return []string{}, false +} + // FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field // name, if one exists. // diff --git a/cel/env_test.go b/cel/env_test.go index 66372796..ed1f836e 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -385,6 +385,10 @@ func (p *customCELProvider) FindStructType(typeName string) (*types.Type, bool) return p.provider.FindStructType(typeName) } +func (p *customCELProvider) FindStructFieldNames(typeName string) ([]string, bool) { + return p.provider.FindStructFieldNames(typeName) +} + func (p *customCELProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { return p.provider.FindStructFieldType(structType, fieldName) } diff --git a/common/types/provider.go b/common/types/provider.go index e9bda552..d301aa38 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -54,6 +54,10 @@ type Provider interface { // Returns false if not found. FindStructType(structType string) (*Type, bool) + // FindStructFieldNames returns thet field names associated with the type, if the type + // is found. + FindStructFieldNames(structType string) ([]string, bool) + // FieldStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. FindStructFieldType(structType, fieldName string) (*FieldType, bool) @@ -173,6 +177,23 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, GetFrom: field.GetFrom}, true } +// FindStructFieldNames returns the set of field names for the given struct type, +// if the type exists in the registry. +func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) { + msgType, found := p.pbdb.DescribeType(structType) + if !found { + return []string{}, false + } + fieldMap := msgType.FieldMap() + fields := make([]string, len(fieldMap)) + idx := 0 + for f := range fieldMap { + fields[idx] = f + idx++ + } + return fields, true +} + // FindStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) { diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 3e93afc5..56f15290 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "reflect" + "sort" "strings" "testing" "time" @@ -132,6 +133,39 @@ func TestRegistryFindStructType(t *testing.T) { } } +func TestRegistryFindStructFieldNames(t *testing.T) { + reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{}) + tests := []struct { + typeName string + fields []string + }{ + { + typeName: "google.api.expr.v1alpha1.Reference", + fields: []string{"name", "overload_id", "value"}, + }, + { + typeName: "google.api.expr.v1alpha1.Decl", + fields: []string{"name", "ident", "function"}, + }, + { + typeName: "invalid.TypeName", + fields: []string{}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) { + fields, _ := reg.FindStructFieldNames(tc.typeName) + sort.Strings(fields) + sort.Strings(tc.fields) + if !reflect.DeepEqual(fields, tc.fields) { + t.Errorf("got %v, wanted %v", fields, tc.fields) + } + }) + } +} + func TestRegistryFindStructFieldType(t *testing.T) { reg := newTestRegistry(t) err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile()) diff --git a/ext/native.go b/ext/native.go index 0b5fc38c..0c2cd52f 100644 --- a/ext/native.go +++ b/ext/native.go @@ -151,6 +151,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool return tp.baseProvider.FindStructType(typeName) } +// FindStructFieldNames looks up the type definition first from the native types, then from +// the backing provider type set. If found, a set of field names corresponding to the type +// will be returned. +func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) { + if t, found := tp.nativeTypes[typeName]; found { + fieldCount := t.refType.NumField() + fields := make([]string, fieldCount) + for i := 0; i < fieldCount; i++ { + fields[i] = t.refType.Field(i).Name + } + return fields, true + } + if celTypeFields, found := tp.baseProvider.FindStructFieldNames(typeName); found { + return celTypeFields, true + } + return tp.baseProvider.FindStructFieldNames(typeName) +} + // FindStructFieldType looks up a native type's field definition, and if the type name is not a native // type then proxies to the composed types.Provider func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) { diff --git a/ext/native_test.go b/ext/native_test.go index 21a67879..ead7bc1c 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -17,6 +17,7 @@ package ext import ( "fmt" "reflect" + "sort" "strings" "testing" "time" @@ -162,6 +163,40 @@ func TestNativeTypes(t *testing.T) { } } +func TestNativeFindStructFieldNames(t *testing.T) { + env := testNativeEnv(t) + provider := env.CELTypeProvider() + tests := []struct { + typeName string + fields []string + }{ + { + typeName: "ext.TestNestedType", + fields: []string{"NestedListVal", "NestedMapVal"}, + }, + { + typeName: "google.expr.proto3.test.TestAllTypes.NestedMessage", + fields: []string{"bb"}, + }, + { + typeName: "invalid.TypeName", + fields: []string{}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) { + fields, _ := provider.FindStructFieldNames(tc.typeName) + sort.Strings(fields) + sort.Strings(tc.fields) + if !reflect.DeepEqual(fields, tc.fields) { + t.Errorf("got %v, wanted %v", fields, tc.fields) + } + }) + } +} + func TestNativeTypesStaticErrors(t *testing.T) { var nativeTests = []struct { expr string