Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamess-Lucass committed Nov 9, 2024
1 parent e09542f commit 0331a83
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 25 deletions.
66 changes: 43 additions & 23 deletions module/gorm/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ func Apply[T any](db *gorm.DB, query goatquery.Query, searchFunc SearchFunc, opt
if options != nil && query.Top > options.MaxTop {
return nil, nil, fmt.Errorf("The value supplied for the query parameter 'Top' was greater than the maximum top allowed for this resource")
}
var err error

var model T

// v := reflect.ValueOf(model)
t := reflect.TypeOf(model)

namer := db.Statement.NamingStrategy
tableName := namer.TableName(t.Name())

Expand All @@ -42,7 +44,10 @@ func Apply[T any](db *gorm.DB, query goatquery.Query, searchFunc SearchFunc, opt

statements := p.ParseFilter()

db = EvaluateFilter(statements.Expression, db, namer, tableName, t)
db, err = EvaluateFilter(statements.Expression, db, namer, tableName, t)
if err != nil {
return db, nil, err
}
}

// Search
Expand All @@ -64,9 +69,12 @@ func Apply[T any](db *gorm.DB, query goatquery.Query, searchFunc SearchFunc, opt
statements := p.ParseOrderBy()

for _, statement := range statements {
property := GetGormColumnName(namer, tableName, t, statement.TokenLiteral())
property, err := GetGormColumnName(namer, tableName, t, statement.TokenLiteral())
if err != nil {
return db, &count, err
}

sql := fmt.Sprintf("%s %s", property, statement.Direction)
sql := fmt.Sprintf("%s %s", *property, statement.Direction)

db = db.Order(sql)
}
Expand All @@ -93,13 +101,14 @@ func Apply[T any](db *gorm.DB, query goatquery.Query, searchFunc SearchFunc, opt
return db, nil, nil
}

func GetGormColumnName(namer schema.Namer, tableName string, t reflect.Type, property string) string {
func GetGormColumnName(namer schema.Namer, tableName string, t reflect.Type, property string) (*string, error) {
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)

if field.Anonymous {
if columnName := GetGormColumnName(namer, tableName, field.Type, property); columnName != "" {
return columnName
columnName, err := GetGormColumnName(namer, tableName, field.Type, property)
if err == nil && *columnName != "" {
return columnName, nil
}
continue
}
Expand All @@ -112,17 +121,19 @@ func GetGormColumnName(namer schema.Namer, tableName string, t reflect.Type, pro
if strings.EqualFold(propertyName, property) {
settings := schema.ParseTagSetting(field.Tag.Get("gorm"), ";")
if settings["COLUMN"] != "" {
return settings["COLUMN"]
col := settings["COLUMN"]
return &col, nil
}

return namer.ColumnName(tableName, field.Name)
col := namer.ColumnName(tableName, field.Name)
return &col, nil
}
}

return namer.ColumnName(tableName, property)
return nil, fmt.Errorf("Property doesn't exist")
}

func EvaluateFilter(exp ast.Expression, db *gorm.DB, namer schema.Namer, tableName string, t reflect.Type) *gorm.DB {
func EvaluateFilter(exp ast.Expression, db *gorm.DB, namer schema.Namer, tableName string, t reflect.Type) (*gorm.DB, error) {
switch exp := exp.(type) {
case *ast.InfixExpression:
identifier, ok := exp.Left.(*ast.Identifier)
Expand All @@ -142,40 +153,49 @@ func EvaluateFilter(exp ast.Expression, db *gorm.DB, namer schema.Namer, tableNa
value = right.Value
}

property := GetGormColumnName(namer, namer.TableName(t.Name()), t, identifier.TokenLiteral())
property, err := GetGormColumnName(namer, namer.TableName(t.Name()), t, identifier.TokenLiteral())
if err != nil {
return db, err
}

switch strings.ToLower(exp.Operator) {
case keywords.EQ:
return db.Where(fmt.Sprintf("%s = ?", property), value)
return db.Where(fmt.Sprintf("%s = ?", *property), value), nil
case keywords.NE:
return db.Where(fmt.Sprintf("%s <> ?", property), value)
return db.Where(fmt.Sprintf("%s <> ?", *property), value), nil
case keywords.CONTAINS:
if str, ok := exp.Right.(*ast.StringLiteral); ok {
return db.Where(fmt.Sprintf("%s LIKE ?", property), "%"+str.Value+"%")
return db.Where(fmt.Sprintf("%s LIKE ?", *property), "%"+str.Value+"%"), nil
}
case keywords.LT:
return db.Where(fmt.Sprintf("%s < ?", property), value)
return db.Where(fmt.Sprintf("%s < ?", *property), value), nil
case keywords.LTE:
return db.Where(fmt.Sprintf("%s <= ?", property), value)
return db.Where(fmt.Sprintf("%s <= ?", *property), value), nil
case keywords.GT:
return db.Where(fmt.Sprintf("%s > ?", property), value)
return db.Where(fmt.Sprintf("%s > ?", *property), value), nil
case keywords.GTE:
return db.Where(fmt.Sprintf("%s >= ?", property), value)
return db.Where(fmt.Sprintf("%s >= ?", *property), value), nil
}
}

left := EvaluateFilter(exp.Left, db, namer, tableName, t)
right := EvaluateFilter(exp.Right, db, namer, tableName, t)
left, err := EvaluateFilter(exp.Left, db, namer, tableName, t)
if err != nil {
return db, err
}
right, err := EvaluateFilter(exp.Right, db, namer, tableName, t)
if err != nil {
return db, err
}

switch exp.Operator {
case keywords.AND:
return left.Where(right)
return left.Where(right), nil
case keywords.OR:
return left.Or(right)
return left.Or(right), nil
}

break
}

return db
return db, nil
}
32 changes: 30 additions & 2 deletions module/gorm/apply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type User struct {
Firstname string
Balance *float64
DateOfBirth time.Time
JsonProp string `json:"random_json_name"`
}

var DB *gorm.DB
Expand All @@ -50,7 +51,7 @@ func timeMustParse(value string) time.Time {
}

var users = map[string]User{
"John": {Base: Base{Age: 2}, Firstname: "John", UserId: uuid.MustParse("58cdeca3-645b-457c-87aa-7d5f87734255"), DateOfBirth: timeMustParse("2004-01-31 23:59:59"), Balance: makePointer(1.50)},
"John": {Base: Base{Age: 2}, Firstname: "John", UserId: uuid.MustParse("58cdeca3-645b-457c-87aa-7d5f87734255"), DateOfBirth: timeMustParse("2004-01-31 23:59:59"), Balance: makePointer(1.50), JsonProp: "user_john"},
"Jane": {Base: Base{Age: 1}, Firstname: "Jane", UserId: uuid.MustParse("58cdeca3-645b-457c-87aa-7d5f87734255"), DateOfBirth: timeMustParse("2020-05-09 15:30:00"), Balance: makePointer(0.0)},
"Apple": {Base: Base{Age: 2}, Firstname: "Apple", UserId: uuid.MustParse("58cdeca3-645b-457c-87aa-7d5f87734255"), DateOfBirth: timeMustParse("1980-12-31 00:00:01"), Balance: makePointer(1204050.98)},
"Harry": {Base: Base{Age: 1}, Firstname: "Harry", UserId: uuid.MustParse("e4c7772b-8947-4e46-98ed-644b417d2a08"), DateOfBirth: timeMustParse("2002-08-01"), Balance: makePointer(0.5372958205929493)},
Expand All @@ -60,7 +61,7 @@ var users = map[string]User{

func setup() {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Error),
Logger: logger.Default.LogMode(logger.Info),
})

if err != nil {
Expand Down Expand Up @@ -564,6 +565,33 @@ func Test_InvalidFilterReturnsError(t *testing.T) {
assert.Error(t, err)
}

func Test_Filter_WithCustomJsonTag(t *testing.T) {
tests := []struct {
input string
expected []User
}{
{"random_json_name eq 'user_john'", []User{
users["John"],
}},
}

for _, test := range tests {

query := goatquery.Query{
Filter: test.input,
}

res, _, err := Apply[User](DB, query, nil, nil)
assert.NoError(t, err)

var output []User
err = res.Find(&output).Error
assert.NoError(t, err)

assert.Equal(t, test.expected, output)
}
}

func makePointer[T any](v T) *T {
return &v
}

0 comments on commit 0331a83

Please sign in to comment.