Skip to content

Commit

Permalink
internal/fields: adds func for validating struct
Browse files Browse the repository at this point in the history
For the datastore client, we need an efficient way to validate that
a slice of structs field has no field (direct fields or fields of fields)
which is itself also a slice.

This feature should be helpful elsewhere too.

Change-Id: I32c16989afe22c06d6a8f11b30c66dbf4638f90f
Reviewed-on: https://code-review.googlesource.com/9772
Reviewed-by: Jonathan Amsterdam <jba@google.com>
  • Loading branch information
Sarah Adams committed Dec 8, 2016
1 parent c9f70e2 commit 5bfd313
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 21 deletions.
2 changes: 1 addition & 1 deletion bigquery/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func bqTagParser(t reflect.StructTag) (name string, keep bool, other interface{}
return "", true, nil, nil
}

var fieldCache = fields.NewCache(bqTagParser)
var fieldCache = fields.NewCache(bqTagParser, nil)

var (
int64ParamType = &bq.QueryParameterType{Type: "INT64"}
Expand Down
38 changes: 29 additions & 9 deletions internal/fields/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,40 @@ type Field struct {

type ParseTagFunc func(reflect.StructTag) (name string, keep bool, other interface{}, err error)

type ValidateFunc func(reflect.Type) (err error)

// A Cache records information about the fields of struct types.
//
// A Cache is safe for use by multiple goroutines.
type Cache struct {
parseTag ParseTagFunc
validate ValidateFunc
cache atomic.Value // map[reflect.Type][]Field
mu sync.Mutex // used only by writers of cache
}

// NewCache constructs a Cache. Its argument should be a function that accepts
// NewCache constructs a Cache.
// Its first argument should be a function that accepts
// a struct tag and returns four values: an alternative name for the field
// extracted from the tag, a boolean saying whether to keep the field or ignore
// it, additional data that is stored with the field information to avoid
// having to parse the tag again, and an error.
func NewCache(parseTag ParseTagFunc) *Cache {
// Its second argument should be a function that accepts a reflect.Type
// and returns an error if the struct type is invalid in any way.
// For example, it may check that all of the struct field tags are valid, or
// that all fields are of an appropriate type.
func NewCache(parseTag ParseTagFunc, validate ValidateFunc) *Cache {
if parseTag == nil {
parseTag = func(reflect.StructTag) (string, bool, interface{}, error) {
return "", true, nil, nil
}
}
return &Cache{parseTag: parseTag}
if validate == nil {
validate = func(reflect.Type) error {
return nil
}
}
return &Cache{parseTag: parseTag, validate: validate}
}

// A fieldScan represents an item on the fieldByNameFunc scan work list.
Expand All @@ -103,9 +116,6 @@ type fieldScan struct {
// follows the standard Go rules for embedded fields, modified by the presence
// of tags. The result is sorted lexicographically by index.
//
// If not nil, the given parseTag function should extract and return a name
// from the struct tag. This name is used instead of the field's declared name.
//
// These rules apply in the absence of tags:
// Anonymous struct fields are treated as if their inner exported fields were
// fields in the outer struct (embedding). The result includes all fields that
Expand Down Expand Up @@ -169,21 +179,31 @@ func (c *Cache) cachedTypeFields(t reflect.Type) (List, error) {
return cv.fields, cv.err
}

// Validate type
if err := c.validate(t); err != nil {
c.add(t, cacheValue{nil, err})
return nil, err
}

// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f, err := c.typeFields(t)
list := List(f)
c.add(t, cacheValue{list, err})
return list, err
}

// add atomically adds a new key-value pair to the cache
func (c *Cache) add(k reflect.Type, v cacheValue) {
c.mu.Lock()
mp, _ = c.cache.Load().(map[reflect.Type]cacheValue)
mp, _ := c.cache.Load().(map[reflect.Type]cacheValue)
newM := make(map[reflect.Type]cacheValue, len(mp)+1)
for k, v := range mp {
newM[k] = v
}
newM[t] = cacheValue{list, err}
newM[k] = v
c.cache.Store(newM)
c.mu.Unlock()
return list, err
}

func (c *Cache) typeFields(t reflect.Type) ([]Field, error) {
Expand Down
57 changes: 46 additions & 11 deletions internal/fields/fields_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func tfield(name string, tval interface{}, index ...int) *Field {
}

func TestFieldsNoTags(t *testing.T) {
c := NewCache(nil)
c := NewCache(nil, nil)
got, err := c.Fields(reflect.TypeOf(S1{}))
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestAgainstJSONEncodingNoTags(t *testing.T) {
jsonRoundTrip(t, s1, &want)
var got S1
got.embed2 = &embed2{} // need this because reflection won't create it
fields, err := NewCache(nil).Fields(reflect.TypeOf(got))
fields, err := NewCache(nil, nil).Fields(reflect.TypeOf(got))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -180,8 +180,22 @@ func jsonTagParser(t reflect.StructTag) (name string, keep bool, other interface
return parts[0], true, other, nil
}

func validateFunc(t reflect.Type) (err error) {
if t.Kind() != reflect.Struct {
return errors.New("non-struct type used")
}

for i := 0; i < t.NumField(); i++ {
if t.Field(i).Type.Kind() == reflect.Slice {
return fmt.Errorf("slice field found at field %s on struct %s", t.Field(i).Name, t.Name())
}
}

return nil
}

func TestFieldsWithTags(t *testing.T) {
got, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S2{}))
got, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S2{}))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -219,7 +233,7 @@ func TestAgainstJSONEncodingWithTags(t *testing.T) {
var want S2
jsonRoundTrip(t, s2, &want)
var got S2
fields, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(got))
fields, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(got))
if err != nil {
t.Fatal(err)
}
Expand All @@ -243,7 +257,7 @@ func TestUnexportedAnonymousNonStruct(t *testing.T) {
}
)

got, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S{}))
got, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S{}))
if err != nil {
t.Fatal(err)
}
Expand All @@ -262,7 +276,7 @@ func TestUnexportedAnonymousStruct(t *testing.T) {
s1 `json:"Y"`
}
)
got, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S2{}))
got, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S2{}))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -304,7 +318,7 @@ func TestIgnore(t *testing.T) {
type S struct {
X int `json:"-"`
}
got, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S{}))
got, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S{}))
if err != nil {
t.Fatal(err)
}
Expand All @@ -317,7 +331,7 @@ func TestParsedTag(t *testing.T) {
type S struct {
X int `json:"name,omitempty"`
}
got, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S{}))
got, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S{}))
if err != nil {
t.Fatal(err)
}
Expand All @@ -330,6 +344,27 @@ func TestParsedTag(t *testing.T) {
}
}

func TestValidateFunc(t *testing.T) {
type MyInvalidStruct struct {
A string
B []int
}

_, err := NewCache(nil, validateFunc).Fields(reflect.TypeOf(MyInvalidStruct{}))
if err == nil {
t.Fatal("expected error, got nil")
}

type MyValidStruct struct {
A string
B int
}
_, err = NewCache(nil, validateFunc).Fields(reflect.TypeOf(MyValidStruct{}))
if err != nil {
t.Fatalf("expected nil, got error: %s\n", err)
}
}

func compareFields(got []Field, want []*Field) (msg string, ok bool) {
if len(got) != len(want) {
return fmt.Sprintf("got %d fields, want %d", len(got), len(want)), false
Expand Down Expand Up @@ -393,7 +428,7 @@ type S4 struct {
}

func TestMatchingField(t *testing.T) {
fields, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S3{}))
fields, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S3{}))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -435,7 +470,7 @@ func TestAgainstJSONMatchingField(t *testing.T) {
var want S3
jsonRoundTrip(t, s3, &want)
v := reflect.ValueOf(want)
fields, err := NewCache(jsonTagParser).Fields(reflect.TypeOf(S3{}))
fields, err := NewCache(jsonTagParser, nil).Fields(reflect.TypeOf(S3{}))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -469,7 +504,7 @@ func TestTagErrors(t *testing.T) {
return "", false, nil, errors.New("error")
}
return s, true, nil, nil
})
}, nil)

type T struct {
X int `f:"ok"`
Expand Down

0 comments on commit 5bfd313

Please sign in to comment.