diff --git a/query_test.go b/query_test.go index bc23088b7..f738ff6bc 100644 --- a/query_test.go +++ b/query_test.go @@ -133,6 +133,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } + func TestStringAgainstIncompleteParentheses(t *testing.T) { type AddressByZipCode struct { ZipCode string `gorm:"primary_key"` @@ -151,6 +152,17 @@ func TestStringAgainstIncompleteParentheses(t *testing.T) { } +func TestStringAgainstIncompleteParenthesesQuoted(t *testing.T) { + DB.Save(&User{Name: "name-)-surname"}) + + var user User + res := DB.Raw("select * from users WHERE name = 'name-)-surname'").First(&user) + + if res.Error != nil { + t.Errorf("Can't execute valid query because error : %s", res.Error.Error()) + } +} + func TestFindAsSliceOfPointers(t *testing.T) { DB.Save(&User{Name: "user"}) diff --git a/scope.go b/scope.go index 54a20c8b4..8c7ef8810 100644 --- a/scope.go +++ b/scope.go @@ -280,16 +280,23 @@ func (scope *Scope) AddToVars(value interface{}) string { // IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection func (scope *Scope) IsCompleteParentheses(value string) bool { count := 0 - for i, _ := range value { - if value[i] == 40 { // ( - count++ - } else if value[i] == 41 { // ) - count-- + unquoted := true + for _, ch := range value { + switch ch { + case '(': + if unquoted { + count++ + } + case ')': + if unquoted { + count-- + } + case '\'': + unquoted = unquoted != true } if count < 0 { break } - i++ } return count == 0 }