Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Feature: support bulk insert to speed insert times #181

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 177 additions & 34 deletions testfixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"path"
"path/filepath"
"regexp"
"sort"
"strings"
"text/template"
"time"
Expand All @@ -26,6 +27,7 @@ type Loader struct {
skipCleanup bool
skipTestDatabaseCheck bool
location *time.Location
bulkInsert bool

template bool
templateFuncs template.FuncMap
Expand Down Expand Up @@ -173,6 +175,13 @@ func UseDropConstraint() func(*Loader) error {
}
}

// UseBulkInsert returns a Loader option that switches on the bulkInsert
// flag. When the flag is set, the insert statements of a fixture file are
// built through the bulk path instead of one statement per record.
func UseBulkInsert() func(*Loader) error {
	enable := func(loader *Loader) error {
		loader.bulkInsert = true
		return nil
	}

	return enable
}

// SkipResetSequences prevents Loader from resetting sequences after loading
// fixtures.
//
Expand Down Expand Up @@ -471,13 +480,18 @@ type InsertError struct {
}

// Error implements the error interface. The message contains the underlying
// error, the fixture file, the index of the record, the SQL statement and the
// statement parameters rendered as a comma-separated list.
func (e *InsertError) Error() string {
	// Format every parameter on its own so the list reads "a, b, c" instead
	// of the bracketed default rendering of a slice.
	params := make([]string, len(e.Params))
	for i, p := range e.Params {
		params[i] = fmt.Sprintf("%v", p)
	}

	// Exactly one argument per verb: passing the raw e.Params slice in
	// addition to the joined string would leave the format with a surplus
	// operand and produce "%!(EXTRA ...)" in the message.
	return fmt.Sprintf(
		"testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
		e.Err,
		e.File,
		e.Index,
		e.SQL,
		strings.Join(params, ", "),
	)
}

Expand Down Expand Up @@ -510,19 +524,94 @@ func (l *Loader) buildInsertSQLs() error {

f.insertSQLs = make([]insertSQL, 0, len(result))

for _, record := range result {
recordMap, ok := record.(map[string]interface{})
if !ok {
return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
if l.bulkInsert {
err = l.buildBulkInsertSQLs(result, f)
if err != nil {
return fmt.Errorf("testfixtures: could not build bulk insert sql: %w", err)
}
} else {
for _, record := range result {
recordMap, ok := record.(map[string]interface{})
if !ok {
return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
}

sql, values, err := l.buildInsertSQL(f, recordMap)
if err != nil {
return err
}

f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
}
}

}

return nil
}

// buildBulkInsertSQLs builds the SQL statements for bulk insert.
// It groups the records by their set of columns and builds one statement per group, because each record
// can have a different set of columns to insert.
func (l *Loader) buildBulkInsertSQLs(result []interface{}, f *fixtureFile) error {
// in order to insert in bulk we need to group them by their columns because dynamic columns
// for bulk insert is not supported
// key of this map is columns separated by comma
groupsColumns := make(map[string][]map[string]any)
keyToColumns := make(map[string][]string)

for _, record := range result {
recordMap, ok := record.(map[string]any)
if !ok {
return fmt.Errorf("could not cast record: not a map[interface{}]interface{}")
}

columns := make([]string, 0, len(recordMap))

for key := range recordMap {
columns = append(columns, key)
}

sort.Strings(columns)

keyColumns := strings.Join(columns, ",")
recordMaps, ok := groupsColumns[keyColumns]

if !ok {
recordMaps = make([]map[string]any, 0)
keyToColumns[keyColumns] = columns
}

groupsColumns[strings.Join(columns, ",")] = append(recordMaps, recordMap)
}

for columnsCommaSeparated, recordMaps := range groupsColumns {
sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n",
l.helper.quoteKeyword(f.fileNameWithoutExtension()),
columnsCommaSeparated,
)

totalValues := make([]any, 0)
var parameterizedValueLines []string
index := 0

for _, recordMap := range recordMaps {
colsInserts, values, err := l.getValuesForBulk(
&index, recordMap,
keyToColumns[columnsCommaSeparated],
)

sql, values, err := l.buildInsertSQL(f, recordMap)
if err != nil {
return err
}

f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
totalValues = append(totalValues, values...)
parameterizedValueLines = append(parameterizedValueLines, colsInserts)
}

sql += strings.Join(parameterizedValueLines, ",\n")

f.insertSQLs = append(f.insertSQLs, insertSQL{sql, totalValues})
}

return nil
Expand All @@ -540,6 +629,37 @@ func (f *fixtureFile) delete(tx *sql.Tx, h helper) error {
return nil
}

// getValuesForBulk renders one record as a single VALUES tuple of a bulk
// insert statement.
//
// The columns are visited in the order given by orderedColumns. Each value is
// passed through getValueForSQL: when it is reported as inline SQL its text is
// written directly into the tuple, otherwise *index is incremented, the
// placeholder for the new index is written into the tuple and the value is
// appended to the returned parameters. index is shared by the caller across
// records so that numbered placeholders keep counting up.
//
// It returns the tuple text, for example "($1, $2, $3)", the parameter values
// matching the placeholders, and the first conversion error, if any.
func (l *Loader) getValuesForBulk(index *int, record map[string]any, orderedColumns []string) (string, []any, error) {
	args := make([]any, 0, len(record))
	tuple := make([]string, 0, len(record))

	for _, column := range orderedColumns {
		converted, inline, err := l.getValueForSQL(record[column])

		switch {
		case err != nil:
			return "", nil, err
		case inline:
			// Inline SQL goes into the statement verbatim and consumes
			// no placeholder number.
			tuple = append(tuple, fmt.Sprint(converted))
		default:
			*index++
			tuple = append(tuple, l.paramSQL(*index))
			args = append(args, converted)
		}
	}

	return "(" + strings.Join(tuple, ", ") + ")", args, nil
}

func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (sqlStr string, values []interface{}, err error) {
var (
sqlColumns = make([]string, 0, len(record))
Expand All @@ -549,37 +669,17 @@ func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (
for key, value := range record {
sqlColumns = append(sqlColumns, l.helper.quoteKeyword(key))

// if string, try convert to SQL or time
// if map or array, convert to json
switch v := value.(type) {
case string:
if strings.HasPrefix(v, "RAW=") {
sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
continue
}
if b, err := l.tryHexStringToBytes(v); err == nil {
value = b
} else if t, err := l.tryStrToDate(v); err == nil {
value = t
}
case []interface{}, map[string]interface{}:
var bytes []byte
bytes, err = json.Marshal(recursiveToJSON(v))
if err != nil {
return
}
value = string(bytes)
value, next, err := l.getValueForSQL(value)
if err != nil {
return "", nil, err
}

switch l.helper.paramType() {
case paramTypeDollar:
sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
case paramTypeQuestion:
sqlValues = append(sqlValues, "?")
case paramTypeAtSign:
sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i))
if next {
sqlValues = append(sqlValues, fmt.Sprint(value))
continue
}

sqlValues = append(sqlValues, l.paramSQL(i))
values = append(values, value)
i++
}
Expand All @@ -593,6 +693,49 @@ func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (
return
}

// paramSQL returns the placeholder text for the parameter at the given index,
// in the style reported by the helper's paramType: "$<index>" for
// paramTypeDollar, "@p<index>" for paramTypeAtSign and "?" for anything else.
func (l *Loader) paramSQL(index int) string {
	style := l.helper.paramType()

	if style == paramTypeDollar {
		return fmt.Sprintf("$%d", index)
	}

	if style == paramTypeAtSign {
		return fmt.Sprintf("@p%d", index)
	}

	// Every remaining style uses the positional question mark.
	return "?"
}

// getValueForSQL converts a fixture value into what should be sent to the
// database.
//
// It returns the converted value, a flag telling whether the value is inline
// SQL that must be written into the statement instead of being bound as a
// parameter, and an error if JSON encoding fails.
//
// Conversion rules:
//   - a string starting with "RAW=" yields the remainder with the flag set;
//   - a string starting with "TEXT=" yields the remainder untouched, which
//     keeps strings such as "20060102" from being interpreted as dates;
//   - any other string is tried as a hex string, then as a date, and is
//     returned unchanged when neither conversion succeeds;
//   - slices and string-keyed maps are encoded to a JSON string;
//   - every other value is returned as is.
func (l *Loader) getValueForSQL(value any) (any, bool, error) {
	switch typed := value.(type) {
	case []any, map[string]any:
		encoded, err := json.Marshal(recursiveToJSON(typed))
		if err != nil {
			return nil, false, err
		}

		return string(encoded), false, nil

	case string:
		switch {
		case strings.HasPrefix(typed, "RAW="):
			return typed[len("RAW="):], true, nil
		case strings.HasPrefix(typed, "TEXT="):
			return typed[len("TEXT="):], false, nil
		}

		if decoded, err := l.tryHexStringToBytes(typed); err == nil {
			return decoded, false, nil
		}

		if parsed, err := l.tryStrToDate(typed); err == nil {
			return parsed, false, nil
		}
	}

	return value, false, nil
}

func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) {
fileinfos, err := fs.ReadDir(l.fs, dir)
if err != nil {
Expand Down