Skip to content

Commit

Permalink
Merge pull request #332 from arnehormann/uint64params
Browse files Browse the repository at this point in the history
support uint64 parameters with high bit set
  • Loading branch information
arnehormann committed May 3, 2015
2 parents 7ff0b8c + b2cd472 commit 0cc29e9
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
43 changes: 43 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,49 @@ func TestNULL(t *testing.T) {
})
}

func TestUint64(t *testing.T) {
const (
u0 = uint64(0)
uall = ^u0
uhigh = uall >> 1
utop = ^uhigh
s0 = int64(0)
sall = ^s0
shigh = int64(uhigh)
stop = ^shigh
)
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
if err != nil {
dbt.Fatal(err)
}
defer stmt.Close()
row := stmt.QueryRow(
u0, uhigh, utop, uall,
s0, shigh, stop, sall,
)

var ua, ub, uc, ud uint64
var sa, sb, sc, sd int64

err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
if err != nil {
dbt.Fatal(err)
}
switch {
case ua != u0,
ub != uhigh,
uc != utop,
ud != uall,
sa != s0,
sb != shigh,
sc != stop,
sd != sall:
dbt.Fatal("Unexpected result value")
}
})
}

func TestLongData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var maxAllowedPacketSize int
Expand Down
37 changes: 37 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ package mysql

import (
"database/sql/driver"
"fmt"
"reflect"
)

type mysqlStmt struct {
Expand All @@ -34,6 +36,10 @@ func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}

func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
Expand Down Expand Up @@ -110,3 +116,34 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {

return rows, err
}

type converter struct{}

func (converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}

rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
}
return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return fmt.Sprintf("%d", u64), nil
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

0 comments on commit 0cc29e9

Please sign in to comment.