Skip to content

Commit

Permalink
Enable gorm filtering by Identifier (#98)
Browse files Browse the repository at this point in the history
* Enable gorm filtering by Identifier

* Fix imports
  • Loading branch information
Daniil Kukharau authored Aug 10, 2018
1 parent 923006b commit 3237ef0
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ f := &query.Filtering{}
s := &query.Sorting{}
p := &query.Pagination{}
fs := &query.FieldSelection{}
gormDB, err = ApplyCollectionOperators(gormDB, f, s, p, fs)
gormDB, err = ApplyCollectionOperators(ctx, gormDB, &PersonORM{}, &Person{}, f, s, p, fs)
if err != nil {
...
}
Expand Down
35 changes: 19 additions & 16 deletions gorm/collection_operators.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
package gorm

import (
"context"
"fmt"
"strings"

"github.com/infobloxopen/atlas-app-toolkit/query"
"github.com/golang/protobuf/proto"
"github.com/jinzhu/gorm"

"github.com/infobloxopen/atlas-app-toolkit/query"
)

// ApplyCollectionOperators applies collection operators to gorm instance db.
func ApplyCollectionOperators(db *gorm.DB, obj interface{}, f *query.Filtering, s *query.Sorting, p *query.Pagination, fs *query.FieldSelection) (*gorm.DB, error) {
db, fAssocToJoin, err := ApplyFiltering(db, f, obj)
func ApplyCollectionOperators(ctx context.Context, db *gorm.DB, obj interface{}, pb proto.Message, f *query.Filtering, s *query.Sorting, p *query.Pagination, fs *query.FieldSelection) (*gorm.DB, error) {
db, fAssocToJoin, err := ApplyFiltering(ctx, db, f, obj, pb)
if err != nil {
return nil, err
}

db, sAssocToJoin, err := ApplySorting(db, s, obj)
db, sAssocToJoin, err := ApplySorting(ctx, db, s, obj)
if err != nil {
return nil, err
}
Expand All @@ -26,14 +29,14 @@ func ApplyCollectionOperators(db *gorm.DB, obj interface{}, f *query.Filtering,
for k := range sAssocToJoin {
fAssocToJoin[k] = struct{}{}
}
db, err = JoinAssociations(db, fAssocToJoin, obj)
db, err = JoinAssociations(ctx, db, fAssocToJoin, obj)
if err != nil {
return nil, err
}

db = ApplyPagination(db, p)
db = ApplyPagination(ctx, db, p)

db, err = ApplyFieldSelection(db, fs, obj)
db, err = ApplyFieldSelection(ctx, db, fs, obj)
if err != nil {
return nil, err
}
Expand All @@ -42,8 +45,8 @@ func ApplyCollectionOperators(db *gorm.DB, obj interface{}, f *query.Filtering,
}

// ApplyFiltering applies filtering operator f to gorm instance db.
func ApplyFiltering(db *gorm.DB, f *query.Filtering, obj interface{}) (*gorm.DB, map[string]struct{}, error) {
str, args, assocToJoin, err := FilteringToGorm(f, obj)
func ApplyFiltering(ctx context.Context, db *gorm.DB, f *query.Filtering, obj interface{}, pb proto.Message) (*gorm.DB, map[string]struct{}, error) {
str, args, assocToJoin, err := FilteringToGorm(ctx, f, obj, pb)
if err != nil {
return nil, nil, err
}
Expand All @@ -54,11 +57,11 @@ func ApplyFiltering(db *gorm.DB, f *query.Filtering, obj interface{}) (*gorm.DB,
}

// ApplySorting applies sorting operator s to gorm instance db.
func ApplySorting(db *gorm.DB, s *query.Sorting, obj interface{}) (*gorm.DB, map[string]struct{}, error) {
func ApplySorting(ctx context.Context, db *gorm.DB, s *query.Sorting, obj interface{}) (*gorm.DB, map[string]struct{}, error) {
var crs []string
var assocToJoin map[string]struct{}
for _, cr := range s.GetCriterias() {
dbName, assoc, err := HandleFieldPath(strings.Split(cr.GetTag(), "."), obj)
dbName, assoc, err := HandleFieldPath(ctx, strings.Split(cr.GetTag(), "."), obj)
if err != nil {
return nil, nil, err
}
Expand All @@ -81,9 +84,9 @@ func ApplySorting(db *gorm.DB, s *query.Sorting, obj interface{}) (*gorm.DB, map
}

// JoinAssociations joins obj's associations from assoc to the current gorm query.
func JoinAssociations(db *gorm.DB, assoc map[string]struct{}, obj interface{}) (*gorm.DB, error) {
func JoinAssociations(ctx context.Context, db *gorm.DB, assoc map[string]struct{}, obj interface{}) (*gorm.DB, error) {
for k := range assoc {
tableName, sourceKeys, targetKeys, err := JoinInfo(obj, k)
tableName, sourceKeys, targetKeys, err := JoinInfo(ctx, obj, k)
if err != nil {
return nil, err
}
Expand All @@ -97,16 +100,16 @@ func JoinAssociations(db *gorm.DB, assoc map[string]struct{}, obj interface{}) (
}

// ApplyPagination applies pagination operator p to gorm instance db.
func ApplyPagination(db *gorm.DB, p *query.Pagination) *gorm.DB {
func ApplyPagination(ctx context.Context, db *gorm.DB, p *query.Pagination) *gorm.DB {
if p != nil {
return db.Offset(p.GetOffset()).Limit(p.DefaultLimit())
}
return db
}

// ApplyFieldSelection applies field selection operator fs to gorm instance db.
func ApplyFieldSelection(db *gorm.DB, fs *query.FieldSelection, obj interface{}) (*gorm.DB, error) {
toPreload, err := FieldSelectionToGorm(fs, obj)
func ApplyFieldSelection(ctx context.Context, db *gorm.DB, fs *query.FieldSelection, obj interface{}) (*gorm.DB, error) {
toPreload, err := FieldSelectionToGorm(ctx, fs, obj)
if err != nil {
return nil, err
}
Expand Down
20 changes: 17 additions & 3 deletions gorm/collection_operators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/infobloxopen/atlas-app-toolkit/gateway"
"github.com/infobloxopen/atlas-app-toolkit/query"
"github.com/jinzhu/gorm"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/infobloxopen/atlas-app-toolkit/gateway"
"github.com/infobloxopen/atlas-app-toolkit/query"
)

type Person struct {
Expand All @@ -22,6 +23,19 @@ type Person struct {
SubPerson SubPerson `gorm:"foreignkey:PersonId;association_foreignkey:Id"`
}

type PersonProto struct {
}

func (*PersonProto) Reset() {
}

func (*PersonProto) ProtoMessage() {
}

func (*PersonProto) String() string {
return "Person"
}

type SubPerson struct {
Id int64
Name string
Expand Down Expand Up @@ -71,7 +85,7 @@ func TestApplyCollectionOperators(t *testing.T) {
gormDB, mock := setUp(t)

rq := req.(*testRequest)
gormDB, err = ApplyCollectionOperators(gormDB, &Person{}, rq.Filtering, rq.Sorting, rq.Pagination, rq.FieldSelection)
gormDB, err = ApplyCollectionOperators(ctx, gormDB, &Person{}, &PersonProto{}, rq.Filtering, rq.Sorting, rq.Pagination, rq.FieldSelection)
if err != nil {
t.Fatal(err)
}
Expand Down
7 changes: 4 additions & 3 deletions gorm/fields.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"fmt"
"reflect"
"sort"
Expand All @@ -12,12 +13,12 @@ import (

// FieldSelectionStringToGorm is a shortcut to parse a string into FieldSelection struct and
// receive a list of associations to preload.
func FieldSelectionStringToGorm(fs string, obj interface{}) ([]string, error) {
return FieldSelectionToGorm(query.ParseFieldSelection(fs), obj)
func FieldSelectionStringToGorm(ctx context.Context, fs string, obj interface{}) ([]string, error) {
return FieldSelectionToGorm(ctx, query.ParseFieldSelection(fs), obj)
}

// FieldSelectionToGorm receives FieldSelection struct and returns a list of associations to preload.
func FieldSelectionToGorm(fs *query.FieldSelection, obj interface{}) ([]string, error) {
func FieldSelectionToGorm(ctx context.Context, fs *query.FieldSelection, obj interface{}) ([]string, error) {
if fs == nil {
return nil, nil
}
Expand Down
6 changes: 4 additions & 2 deletions gorm/fields_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package gorm

import (
"github.com/stretchr/testify/assert"
"context"
"testing"

"github.com/stretchr/testify/assert"
)

type Model struct {
Expand Down Expand Up @@ -63,7 +65,7 @@ func TestGormFieldSelection(t *testing.T) {
},
}
for _, test := range tests {
toPreload, err := FieldSelectionStringToGorm(test.fs, &Model{})
toPreload, err := FieldSelectionStringToGorm(context.Background(), test.fs, &Model{})
if test.err {
assert.Nil(t, toPreload)
assert.NotNil(t, err)
Expand Down
111 changes: 88 additions & 23 deletions gorm/filtering.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,62 @@
package gorm

import (
"context"
"fmt"
"reflect"

"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/generator"

"github.com/infobloxopen/atlas-app-toolkit/query"
"github.com/infobloxopen/atlas-app-toolkit/rpc/resource"
)

// FilterStringToGorm is a shortcut to parse a filter string using default FilteringParser implementation
// and call FilteringToGorm on the returned filtering expression.
func FilterStringToGorm(filter string, obj interface{}) (string, []interface{}, map[string]struct{}, error) {
func FilterStringToGorm(ctx context.Context, filter string, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) {
f, err := query.ParseFiltering(filter)
if err != nil {
return "", nil, nil, err
}
return FilteringToGorm(f, obj)
return FilteringToGorm(ctx, f, obj, pb)
}

// FilteringToGorm returns GORM Plain SQL representation of the filtering expression.
func FilteringToGorm(m *query.Filtering, obj interface{}) (string, []interface{}, map[string]struct{}, error) {
func FilteringToGorm(ctx context.Context, m *query.Filtering, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) {
if m == nil || m.Root == nil {
return "", nil, nil, nil
}
switch r := m.Root.(type) {
case *query.Filtering_Operator:
return LogicalOperatorToGorm(r.Operator, obj)
return LogicalOperatorToGorm(ctx, r.Operator, obj, pb)
case *query.Filtering_StringCondition:
return StringConditionToGorm(r.StringCondition, obj)
return StringConditionToGorm(ctx, r.StringCondition, obj, pb)
case *query.Filtering_NumberCondition:
return NumberConditionToGorm(r.NumberCondition, obj)
return NumberConditionToGorm(ctx, r.NumberCondition, obj, pb)
case *query.Filtering_NullCondition:
return NullConditionToGorm(r.NullCondition, obj)
return NullConditionToGorm(ctx, r.NullCondition, obj, pb)
default:
return "", nil, nil, fmt.Errorf("%T type is not supported in Filtering", r)
}
}

// LogicalOperatorToGorm returns GORM Plain SQL representation of the logical operator.
func LogicalOperatorToGorm(lop *query.LogicalOperator, obj interface{}) (string, []interface{}, map[string]struct{}, error) {
func LogicalOperatorToGorm(ctx context.Context, lop *query.LogicalOperator, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) {
var lres string
var largs []interface{}
var lAssocToJoin map[string]struct{}
var err error
switch l := lop.Left.(type) {
case *query.LogicalOperator_LeftOperator:
lres, largs, lAssocToJoin, err = LogicalOperatorToGorm(l.LeftOperator, obj)
lres, largs, lAssocToJoin, err = LogicalOperatorToGorm(ctx, l.LeftOperator, obj, pb)
case *query.LogicalOperator_LeftStringCondition:
lres, largs, lAssocToJoin, err = StringConditionToGorm(l.LeftStringCondition, obj)
lres, largs, lAssocToJoin, err = StringConditionToGorm(ctx, l.LeftStringCondition, obj, pb)
case *query.LogicalOperator_LeftNumberCondition:
lres, largs, lAssocToJoin, err = NumberConditionToGorm(l.LeftNumberCondition, obj)
lres, largs, lAssocToJoin, err = NumberConditionToGorm(ctx, l.LeftNumberCondition, obj, pb)
case *query.LogicalOperator_LeftNullCondition:
lres, largs, lAssocToJoin, err = NullConditionToGorm(l.LeftNullCondition, obj)
lres, largs, lAssocToJoin, err = NullConditionToGorm(ctx, l.LeftNullCondition, obj, pb)
default:
return "", nil, nil, fmt.Errorf("%T type is not supported in Filtering", l)
}
Expand All @@ -62,13 +69,13 @@ func LogicalOperatorToGorm(lop *query.LogicalOperator, obj interface{}) (string,
var rAssocToJoin map[string]struct{}
switch r := lop.Right.(type) {
case *query.LogicalOperator_RightOperator:
rres, rargs, rAssocToJoin, err = LogicalOperatorToGorm(r.RightOperator, obj)
rres, rargs, rAssocToJoin, err = LogicalOperatorToGorm(ctx, r.RightOperator, obj, pb)
case *query.LogicalOperator_RightStringCondition:
rres, rargs, rAssocToJoin, err = StringConditionToGorm(r.RightStringCondition, obj)
rres, rargs, rAssocToJoin, err = StringConditionToGorm(ctx, r.RightStringCondition, obj, pb)
case *query.LogicalOperator_RightNumberCondition:
rres, rargs, rAssocToJoin, err = NumberConditionToGorm(r.RightNumberCondition, obj)
rres, rargs, rAssocToJoin, err = NumberConditionToGorm(ctx, r.RightNumberCondition, obj, pb)
case *query.LogicalOperator_RightNullCondition:
rres, rargs, rAssocToJoin, err = NullConditionToGorm(r.RightNullCondition, obj)
rres, rargs, rAssocToJoin, err = NullConditionToGorm(ctx, r.RightNullCondition, obj, pb)
default:
return "", nil, nil, fmt.Errorf("%T type is not supported in Filtering", r)
}
Expand Down Expand Up @@ -98,9 +105,9 @@ func LogicalOperatorToGorm(lop *query.LogicalOperator, obj interface{}) (string,
}

// StringConditionToGorm returns GORM Plain SQL representation of the string condition.
func StringConditionToGorm(c *query.StringCondition, obj interface{}) (string, []interface{}, map[string]struct{}, error) {
func StringConditionToGorm(ctx context.Context, c *query.StringCondition, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) {
var assocToJoin map[string]struct{}
dbName, assoc, err := HandleFieldPath(c.FieldPath, obj)
dbName, assoc, err := HandleFieldPath(ctx, c.FieldPath, obj)
if err != nil {
return "", nil, nil, err
}
Expand Down Expand Up @@ -128,13 +135,71 @@ func StringConditionToGorm(c *query.StringCondition, obj interface{}) (string, [
neg = "NOT"
}

return fmt.Sprintf("%s(%s %s ?)", neg, dbName, o), []interface{}{c.Value}, assocToJoin, nil
var value interface{}
if v, err := processStringCondition(ctx, c, pb); err != nil {
value = c.Value
} else {
value = v
}

return fmt.Sprintf("%s(%s %s ?)", neg, dbName, o), []interface{}{value}, assocToJoin, nil
}

func processStringCondition(ctx context.Context, c *query.StringCondition, pb proto.Message) (interface{}, error) {
objType := indirectType(reflect.ValueOf(pb).Type())
pathLength := len(c.FieldPath)
for i, part := range c.FieldPath {
sf, ok := objType.FieldByName(generator.CamelCase(part))
if !ok {
return nil, fmt.Errorf("Cannot find field %s in %s", part, objType)
}
if i < pathLength-1 {
objType = indirectType(sf.Type)
if !isProtoMessage(objType) {
return nil, fmt.Errorf("%s: non-last field of %s field path should be a proto message", objType, c.FieldPath)
}
} else {
if isIdentifier(indirectType(sf.Type)) {
id := &resource.Identifier{}
if err := jsonpb.UnmarshalString(fmt.Sprintf("\"%s\"", c.Value), id); err != nil {
return nil, err
}
newPb := reflect.New(objType)
v := newPb.Elem().FieldByName(generator.CamelCase(part))
v.Set(reflect.ValueOf(id))
toOrm := newPb.MethodByName("ToORM")
if !toOrm.IsValid() {
return nil, fmt.Errorf("ToORM method cannot be found for %s", objType)
}
res := toOrm.Call([]reflect.Value{reflect.ValueOf(ctx)})
if len(res) != 2 {
return nil, fmt.Errorf("ToORM signature of %s is unknown", objType)
}
orm := res[0]
err := res[1]
if !err.IsNil() {
if tErr, ok := err.Interface().(error); ok {
return nil, tErr
} else {
return nil, fmt.Errorf("ToOrm second return value of %s is expected to be error", objType)
}
}
ormId := orm.FieldByName(generator.CamelCase(part))
if !ormId.IsValid() {
return nil, fmt.Errorf("Cannot find field %s in %s", part, objType)
}
return reflect.Indirect(ormId).Interface(), nil

}
}
}
return c.Value, nil
}

// NumberConditionToGorm returns GORM Plain SQL representation of the number condition.
func NumberConditionToGorm(c *query.NumberCondition, obj interface{}) (string, []interface{}, map[string]struct{}, error) {
func NumberConditionToGorm(ctx context.Context, c *query.NumberCondition, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) {
var assocToJoin map[string]struct{}
dbName, assoc, err := HandleFieldPath(c.FieldPath, obj)
dbName, assoc, err := HandleFieldPath(ctx, c.FieldPath, obj)
if err != nil {
return "", nil, nil, err
}
Expand Down Expand Up @@ -163,9 +228,9 @@ func NumberConditionToGorm(c *query.NumberCondition, obj interface{}) (string, [
}

// NullConditionToGorm returns GORM Plain SQL representation of the null condition.
func NullConditionToGorm(c *query.NullCondition, obj interface{}) (string, []interface{}, map[string]struct{}, error) {
func NullConditionToGorm(ctx context.Context, c *query.NullCondition, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) {
var assocToJoin map[string]struct{}
dbName, assoc, err := HandleFieldPath(c.FieldPath, obj)
dbName, assoc, err := HandleFieldPath(ctx, c.FieldPath, obj)
if err != nil {
return "", nil, nil, err
}
Expand Down
Loading

0 comments on commit 3237ef0

Please sign in to comment.