diff --git a/testfixtures.go b/testfixtures.go
index ee53827..198b997 100644
--- a/testfixtures.go
+++ b/testfixtures.go
@@ -10,6 +10,7 @@ import (
 	"path"
 	"path/filepath"
 	"regexp"
+	"sort"
 	"strings"
 	"text/template"
 	"time"
@@ -26,6 +27,7 @@ type Loader struct {
 	skipCleanup           bool
 	skipTestDatabaseCheck bool
 	location              *time.Location
+	bulkInsert            bool
 
 	template           bool
 	templateFuncs      template.FuncMap
@@ -173,6 +175,13 @@ func UseDropConstraint() func(*Loader) error {
 	}
 }
 
+func UseBulkInsert() func(*Loader) error {
+	return func(l *Loader) error {
+		l.bulkInsert = true
+		return nil
+	}
+}
+
 // SkipResetSequences prevents Loader from reseting sequences after loading
 // fixtures.
 //
@@ -471,13 +480,18 @@ type InsertError struct {
 }
 
 func (e *InsertError) Error() string {
+	params := make([]string, len(e.Params))
+	for i, p := range e.Params {
+		params[i] = fmt.Sprintf("%v", p)
+	}
+
 	return fmt.Sprintf(
 		"testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
 		e.Err,
 		e.File,
 		e.Index,
 		e.SQL,
-		e.Params,
+		strings.Join(params, ", "),
 	)
 }
 
@@ -510,19 +524,94 @@ func (l *Loader) buildInsertSQLs() error {
 
 		f.insertSQLs = make([]insertSQL, 0, len(result))
 
-		for _, record := range result {
-			recordMap, ok := record.(map[string]interface{})
-			if !ok {
-				return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
+		if l.bulkInsert {
+			err = l.buildBulkInsertSQLs(result, f)
+			if err != nil {
+				return fmt.Errorf("testfixtures: could not build bulk insert sql: %w", err)
 			}
 
+		} else {
+			for _, record := range result {
+				recordMap, ok := record.(map[string]interface{})
+				if !ok {
+					return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
+				}
+
+				sql, values, err := l.buildInsertSQL(f, recordMap)
+				if err != nil {
+					return err
+				}
+
+				f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
+			}
+		}
+
+	}
+
+	return nil
+}
+
+// buildBulkInsertSQLs builds the SQL statements for bulk inserts.
+// Records are grouped by their set of columns before the statements are built,
+// because each record may declare a different set of columns to insert.
+func (l *Loader) buildBulkInsertSQLs(result []interface{}, f *fixtureFile) error {
+	// To insert in bulk, records have to be grouped by their columns, since a single
+	// INSERT statement cannot mix different column sets.
+	// The key of this map is the comma-separated list of column names.
+	groupsColumns := make(map[string][]map[string]any)
+	keyToColumns := make(map[string][]string)
+
+	for _, record := range result {
+		recordMap, ok := record.(map[string]any)
+		if !ok {
+			return fmt.Errorf("could not cast record: not a map[interface{}]interface{}")
+		}
+
+		columns := make([]string, 0, len(recordMap))
+
+		for key := range recordMap {
+			columns = append(columns, key)
+		}
+
+		sort.Strings(columns)
+
+		keyColumns := strings.Join(columns, ",")
+		recordMaps, ok := groupsColumns[keyColumns]
+
+		if !ok {
+			recordMaps = make([]map[string]any, 0)
+			keyToColumns[keyColumns] = columns
+		}
+
+		groupsColumns[keyColumns] = append(recordMaps, recordMap)
+	}
+
+	for columnsCommaSeparated, recordMaps := range groupsColumns {
+		sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n",
+			l.helper.quoteKeyword(f.fileNameWithoutExtension()),
+			columnsCommaSeparated,
+		)
+
+		totalValues := make([]any, 0)
+		var parameterizedValueLines []string
+		index := 0
+
+		for _, recordMap := range recordMaps {
+			valuesSQL, values, err := l.getValuesForBulk(
+				&index, recordMap,
+				keyToColumns[columnsCommaSeparated],
+			)
-			sql, values, err := l.buildInsertSQL(f, recordMap)
 			if err != nil {
 				return err
 			}
 
-			f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
+			totalValues = append(totalValues, values...)
+			parameterizedValueLines = append(parameterizedValueLines, valuesSQL)
 		}
+
+		sql += strings.Join(parameterizedValueLines, ",\n")
+
+		f.insertSQLs = append(f.insertSQLs, insertSQL{sql, totalValues})
 	}
 
 	return nil
@@ -540,6 +629,37 @@ func (f *fixtureFile) delete(tx *sql.Tx, h helper) error {
 	return nil
 }
 
+// getValuesForBulk returns the SQL fragment and the bound values for a single record.
+// The first return value is the parameterized value list, e.g. "(?, ?, ?)".
+// The second return value is the slice of values bound to those parameters.
+func (l *Loader) getValuesForBulk(index *int, record map[string]any, orderedColumns []string) (string, []any, error) {
+	var (
+		sqlValues  = make([]any, 0, len(record))
+		sqlColumns = make([]string, 0, len(record))
+	)
+
+	for _, col := range orderedColumns {
+
+		value := record[col]
+		appropriateValue, isRaw, err := l.getValueForSQL(value)
+		if err != nil {
+			return "", nil, err
+		}
+
+		if isRaw {
+			sqlColumns = append(sqlColumns, fmt.Sprint(appropriateValue))
+			continue
+		}
+
+		*index++
+
+		sqlColumns = append(sqlColumns, l.paramSQL(*index))
+		sqlValues = append(sqlValues, appropriateValue)
+	}
+
+	return fmt.Sprintf("(%s)", strings.Join(sqlColumns, ", ")), sqlValues, nil
+}
+
 func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (sqlStr string, values []interface{}, err error) {
 	var (
 		sqlColumns = make([]string, 0, len(record))
@@ -549,37 +669,17 @@ func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (
 	for key, value := range record {
 		sqlColumns = append(sqlColumns, l.helper.quoteKeyword(key))
 
-		// if string, try convert to SQL or time
-		// if map or array, convert to json
-		switch v := value.(type) {
-		case string:
-			if strings.HasPrefix(v, "RAW=") {
-				sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
-				continue
-			}
-			if b, err := l.tryHexStringToBytes(v); err == nil {
-				value = b
-			} else if t, err := l.tryStrToDate(v); err == nil {
-				value = t
-			}
-		case []interface{}, map[string]interface{}:
-			var bytes []byte
-			bytes, err = json.Marshal(recursiveToJSON(v))
-			if err != nil {
-				return
-			}
-			value = string(bytes)
+		value, isRaw, err := l.getValueForSQL(value)
+		if err != nil {
+			return "", nil, err
 		}
 
-		switch l.helper.paramType() {
-		case paramTypeDollar:
-			sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
-		case paramTypeQuestion:
-			sqlValues = append(sqlValues, "?")
-		case paramTypeAtSign:
-			sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i))
+		if isRaw {
+			sqlValues = append(sqlValues, fmt.Sprint(value))
+			continue
 		}
 
+		sqlValues = append(sqlValues, l.paramSQL(i))
 		values = append(values, value)
 		i++
 	}
@@ -593,6 +693,49 @@ func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (
 	return
 }
 
+func (l *Loader) paramSQL(index int) string {
+	switch l.helper.paramType() {
+	case paramTypeDollar:
+		return fmt.Sprintf("$%d", index)
+	case paramTypeAtSign:
+		return fmt.Sprintf("@p%d", index)
+	default:
+		return "?"
+	}
+}
+
+// getValueForSQL converts a fixture value into the value to bind (or embed) in the SQL.
+// Strings may be converted to raw SQL, bytes, or a time; maps and arrays are converted to JSON.
+// The boolean return reports whether the value must be embedded verbatim instead of bound as a parameter.
+func (l *Loader) getValueForSQL(value any) (any, bool, error) {
+	switch v := value.(type) {
+	case string:
+		if strings.HasPrefix(v, "RAW=") {
+			return strings.TrimPrefix(v, "RAW="), true, nil
+		}
+
+		// The TEXT= prefix keeps the value as a plain string; without it, strings such as "20060102" would be interpreted as dates.
+		if strings.HasPrefix(v, "TEXT=") {
+			return strings.TrimPrefix(v, "TEXT="), false, nil
+		}
+
+		if b, err := l.tryHexStringToBytes(v); err == nil {
+			return b, false, nil
+		} else if t, err := l.tryStrToDate(v); err == nil {
+			return t, false, nil
+		}
+	case []interface{}, map[string]interface{}:
+		var bytes []byte
+		bytes, err := json.Marshal(recursiveToJSON(v))
+		if err != nil {
+			return nil, false, err
+		}
+		return string(bytes), false, nil
+	}
+
+	return value, false, nil
+}
+
 func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) {
 	fileinfos, err := fs.ReadDir(l.fs, dir)
 	if err != nil {
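A minimal sketch of how the new option could be wired up, not taken from this patch itself: it assumes the library's existing `New`, `Database`, `Dialect`, and `Directory` options, and the driver, DSN, and `testdata/fixtures` path are placeholders.

```go
package main

import (
	"database/sql"
	"log"

	"github.com/go-testfixtures/testfixtures/v3"
	_ "github.com/lib/pq" // placeholder driver; any supported dialect works
)

func main() {
	// Placeholder DSN: point this at a throwaway test database.
	db, err := sql.Open("postgres", "dbname=app_test sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}

	fixtures, err := testfixtures.New(
		testfixtures.Database(db),
		testfixtures.Dialect("postgres"),
		testfixtures.Directory("testdata/fixtures"),
		// New in this change: group records that share the same column set
		// into a single multi-row INSERT instead of one INSERT per record.
		testfixtures.UseBulkInsert(),
	)
	if err != nil {
		log.Fatal(err)
	}

	if err := fixtures.Load(); err != nil {
		log.Fatal(err)
	}
}
```

With `UseBulkInsert()` enabled, records in a fixture file that share the same column set are collapsed into one multi-row statement: for a hypothetical `users.yml` whose records all declare `id` and `name`, the PostgreSQL helper would receive a single statement of the form `INSERT INTO "users" (id,name) VALUES ($1, $2), ($3, $4)` with all values in one parameter list, while records with differing column sets still go into separate statements.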
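The `TEXT=` prefix introduced in `getValueForSQL` complements the existing `RAW=` prefix. Below is a hypothetical fixture record (field names and values are illustrative only, not from this patch), embedded as a Go raw string to keep the example in a single language.

```go
package main

import "fmt"

// exampleRecord is a hypothetical YAML fixture record illustrating the
// string prefixes understood by getValueForSQL after this change.
const exampleRecord = `
- id: 1
  created_at: RAW=now()       # embedded verbatim in the SQL, no bind parameter
  batch_code: TEXT=20060102   # stays the string "20060102" instead of being parsed as a date
  note: plain text            # unchanged behavior: bound as a regular parameter
`

func main() {
	fmt.Print(exampleRecord)
}
```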