diff --git a/bind.go b/bind.go index b272d2306a..406ca6562f 100644 --- a/bind.go +++ b/bind.go @@ -110,41 +110,64 @@ func checkAllNamedArguments(args ...any) (bool, error) { return haveNamed, nil } -var bindPositionCharRe = regexp.MustCompile(`[?]`) - func bindPositional(tz *time.Location, query string, args ...any) (_ string, err error) { var ( - unbind = make(map[int]struct{}) - params = make([]string, len(args)) + lastMatchIndex = -1 // Position of previous match for copying + argIndex = 0 // Index for the argument at current position + buf = make([]byte, 0, len(query)) + unbindCount = 0 // Number of positional arguments that couldn't be matched ) - for i, v := range args { - if fn, ok := v.(std_driver.Valuer); ok { - if v, err = fn.Value(); err != nil { - return "", nil + + for i := 0; i < len(query); i++ { + // It's fine looping through the query string as bytes, because the (fixed) characters we're looking for + // are in the ASCII range to won't take up more than one byte. + if query[i] == '?' { + if i > 0 && query[i-1] == '\\' { + // Copy all previous index to here characters + buf = append(buf, query[lastMatchIndex+1:i-1]...) + buf = append(buf, '?') + } else { + // Copy all previous index to here characters + buf = append(buf, query[lastMatchIndex+1:i]...) + + // Append the argument value + if argIndex < len(args) { + v := args[argIndex] + if fn, ok := v.(std_driver.Valuer); ok { + if v, err = fn.Value(); err != nil { + return "", nil + } + } + + value, err := format(tz, Seconds, v) + if err != nil { + return "", err + } + + buf = append(buf, value...) + argIndex++ + } else { + unbindCount++ + } } - } - params[i], err = format(tz, Seconds, v) - if err != nil { - return "", err + + lastMatchIndex = i } } - i := 0 - query = bindPositionalRe.ReplaceAllStringFunc(query, func(n string) string { - if i >= len(params) { - unbind[i] = struct{}{} - return "" - } - val := params[i] - i++ - return bindPositionCharRe.ReplaceAllStringFunc(n, func(m string) string { - return val - }) - }) - for param := range unbind { - return "", fmt.Errorf("have no arg for param ? at position %d", param) + + // If there were no replacements, quick return without copying the string + if lastMatchIndex < 0 { + return query, nil + } + + // Append the remainder + buf = append(buf, query[lastMatchIndex+1:]...) + + if unbindCount > 0 { + return "", fmt.Errorf("have no arg for param ? at last %d positions", unbindCount) } - // replace \? escape sequence - return strings.ReplaceAll(query, "\\?", "?"), nil + + return string(buf), nil } func bindNumeric(tz *time.Location, query string, args ...any) (_ string, err error) { @@ -244,9 +267,11 @@ func formatTime(tz *time.Location, scale TimeUnit, value time.Time) (string, err return fmt.Sprintf("toDateTime64('%s', %d, '%s')", value.Format(fmt.Sprintf("2006-01-02 15:04:05.%0*d", int(scale*3), 0)), int(scale*3), value.Location().String()), nil } +var stringQuoteReplacer = strings.NewReplacer(`\`, `\\`, `'`, `\'`) + func format(tz *time.Location, scale TimeUnit, v any) (string, error) { quote := func(v string) string { - return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'" + return "'" + stringQuoteReplacer.Replace(v) + "'" } switch v := v.(type) { case nil: