Skip to content

Commit

Permalink
Merge pull request #718 from abraithwaite/alan/fix-named-batch
Browse files Browse the repository at this point in the history
NamedExec Bulk Insert Fix
  • Loading branch information
jmoiron authored Apr 8, 2021
2 parents a1d5e64 + df9bf98 commit 1723f86
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 9 deletions.
19 changes: 13 additions & 6 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,28 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper)
return bound, arglist, nil
}

var valueBracketReg = regexp.MustCompile(`\([^(]*.[^(]\)\s*$`)
var valueBracketReg = regexp.MustCompile(`VALUES\s+(\([^(]*.[^(]\))`)

This comment has been minimized.

Copy link
@fifsky

fifsky Apr 8, 2021

This regularization is too strict, this leads to the failure of all the following cases:

insert into users(a,b) values(:a,:b)
insert into users(a,b) values (:a,:b)
insert into users(a,b) Values (:a,:b)
insert into users(a,b) VALUES(:a,:b)

This comment has been minimized.

Copy link
@jmoiron

jmoiron Apr 8, 2021

Author Owner

you're right, case insensitive match on VALUES and \s* instead of \s+ should fix those, going to work up some tests.


func fixBound(bound string, loop int) string {
loc := valueBracketReg.FindStringIndex(bound)
if len(loc) != 2 {
loc := valueBracketReg.FindAllStringSubmatchIndex(bound, -1)
// Either no VALUES () found or more than one found??
if len(loc) != 1 {
return bound
}
// defensive guard. loc should be len 4 representing the starting and
// ending index for the whole regex match and the starting + ending
// index for the single inside group
if len(loc[0]) != 4 {
return bound
}
var buffer bytes.Buffer

buffer.WriteString(bound[0:loc[1]])
buffer.WriteString(bound[0:loc[0][1]])
for i := 0; i < loop-1; i++ {
buffer.WriteString(",")
buffer.WriteString(bound[loc[0]:loc[1]])
buffer.WriteString(bound[loc[0][2]:loc[0][3]])
}
buffer.WriteString(bound[loc[1]:])
buffer.WriteString(bound[loc[0][1]:])
return buffer.String()
}

Expand Down
77 changes: 74 additions & 3 deletions named_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sqlx
import (
"database/sql"
"fmt"
"regexp"
"testing"
)

Expand Down Expand Up @@ -202,7 +203,10 @@ func TestNamedQueries(t *testing.T) {
{FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"},
}

insert := fmt.Sprintf("INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n", now)
insert := fmt.Sprintf(
"INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n",
now,
)
_, err = db.NamedExec(insert, sls)
test.Error(err)

Expand All @@ -214,7 +218,7 @@ func TestNamedQueries(t *testing.T) {
}

_, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
VALUES (:first_name, :last_name, :email) `, slsMap)
VALUES (:first_name, :last_name, :email) ;--`, slsMap)
test.Error(err)

type A map[string]interface{}
Expand All @@ -226,7 +230,7 @@ func TestNamedQueries(t *testing.T) {
}

_, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
VALUES (:first_name, :last_name, :email) `, typedMap)
VALUES (:first_name, :last_name, :email) ;--`, typedMap)
test.Error(err)

for _, p := range sls {
Expand Down Expand Up @@ -296,3 +300,70 @@ func TestNamedQueries(t *testing.T) {

})
}

func TestFixBounds(t *testing.T) {
table := []struct {
name, query, expect string
loop int
}{
{
name: `named syntax`,
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`,
loop: 2,
},
{
name: `mysql syntax`,
query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`,
loop: 2,
},
{
name: `named syntax w/ trailer`,
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`,
loop: 2,
},
{
name: `mysql syntax w/ trailer`,
query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) ;--`,
loop: 2,
},
{
name: `not found test`,
query: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`,
expect: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`,
loop: 2,
},
{
name: `found twice test`,
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`,
loop: 2,
},
}

for _, tc := range table {
t.Run(tc.name, func(t *testing.T) {
res := fixBound(tc.query, tc.loop)
if res != tc.expect {
t.Errorf("mismatched results")
}
})
}

t.Run("regex changed", func(t *testing.T) {
var valueBracketRegChanged = regexp.MustCompile(`(VALUES)\s+(\([^(]*.[^(]\))`)
saveRegexp := valueBracketReg
defer func() {
valueBracketReg = saveRegexp
}()
valueBracketReg = valueBracketRegChanged

res := fixBound("VALUES (:a, :b)", 2)
if res != "VALUES (:a, :b)" {
t.Errorf("changed regex should return string")
}
})
}

0 comments on commit 1723f86

Please sign in to comment.