@@ -9,9 +9,11 @@ import (
99 "crypto/tls"
1010 "crypto/x509"
1111 "database/sql"
12+ "database/sql/driver"
1213 "encoding/json"
1314 "fmt"
1415 "io/ioutil"
16+ "reflect"
1517 "strconv"
1618 "time"
1719
@@ -203,7 +205,7 @@ func (m *Mysql) query(s string) ([]byte, error) {
203205 _ = rows .Err ()
204206 }()
205207
206- result , err := jsonify (rows )
208+ result , err := m . jsonify (rows )
207209 if err != nil {
208210 return nil , errors .Wrapf (err , "error marshalling query result for %s" , s )
209211 }
@@ -277,92 +279,66 @@ func initDB(url, pemPath string) (*sql.DB, error) {
277279 return db , nil
278280}
279281
280- func jsonify (rows * sql.Rows ) ([]byte , error ) {
282+ func ( m * Mysql ) jsonify (rows * sql.Rows ) ([]byte , error ) {
281283 columnTypes , err := rows .ColumnTypes ()
282284 if err != nil {
283285 return nil , err
284286 }
285287
286288 var ret []interface {}
287289 for rows .Next () {
288- scanArgs := prepareScanArgs (columnTypes )
289- err := rows .Scan (scanArgs ... )
290+ values := prepareValues (columnTypes )
291+ err := rows .Scan (values ... )
290292 if err != nil {
291293 return nil , err
292294 }
293295
294- r := convertScanArgs (columnTypes , scanArgs )
296+ r := m . convert (columnTypes , values )
295297 ret = append (ret , r )
296298 }
297299
298300 return json .Marshal (ret )
299301}
300302
301- func convertScanArgs (columnTypes []* sql.ColumnType , scanArgs []interface {}) map [string ]interface {} {
302- r := map [string ]interface {}{}
303-
304- for i , v := range columnTypes {
305- if s , ok := (scanArgs [i ]).(* sql.NullString ); ok {
306- r [v .Name ()] = s .String
307-
308- continue
309- }
310-
311- if s , ok := (scanArgs [i ]).(* sql.NullBool ); ok {
312- r [v .Name ()] = s .Bool
313-
314- continue
315- }
316-
317- if s , ok := (scanArgs [i ]).(* sql.NullInt32 ); ok {
318- r [v .Name ()] = s .Int32
319-
320- continue
321- }
303+ func prepareValues (columnTypes []* sql.ColumnType ) []interface {} {
304+ types := make ([]reflect.Type , len (columnTypes ))
305+ for i , tp := range columnTypes {
306+ types [i ] = tp .ScanType ()
307+ }
322308
323- if s , ok := (scanArgs [i ]).(* sql.NullInt64 ); ok {
324- r [v .Name ()] = s .Int64
309+ values := make ([]interface {}, len (columnTypes ))
310+ for i := range values {
311+ values [i ] = reflect .New (types [i ]).Interface ()
312+ }
325313
326- continue
327- }
314+ return values
315+ }
328316
329- if s , ok := ( scanArgs [ i ]).( * sql.NullFloat64 ); ok {
330- r [ v . Name ()] = s . Float64
317+ func ( m * Mysql ) convert ( columnTypes [] * sql.ColumnType , values [] interface {}) map [ string ] interface {} {
318+ r := map [ string ] interface {}{}
331319
332- continue
320+ for i , ct := range columnTypes {
321+ value := values [i ]
322+
323+ switch v := values [i ].(type ) {
324+ case driver.Valuer :
325+ if vv , err := v .Value (); err == nil {
326+ value = interface {}(vv )
327+ } else {
328+ m .logger .Warnf ("error to convert value: %v" , err )
329+ }
330+ case * sql.RawBytes :
331+ // special case for sql.RawBytes, see https://github.com/go-sql-driver/mysql/blob/master/fields.go#L178
332+ switch ct .DatabaseTypeName () {
333+ case "VARCHAR" , "CHAR" :
334+ value = string (* v )
335+ }
333336 }
334337
335- if s , ok := (scanArgs [i ]).(* sql.NullTime ); ok {
336- r [v .Name ()] = s .Time
337-
338- continue
338+ if value != nil {
339+ r [ct .Name ()] = value
339340 }
340-
341- // this won't happen since the default switch is sql.NullString
342- r [v .Name ()] = scanArgs [i ]
343341 }
344342
345343 return r
346344}
347-
348- func prepareScanArgs (columnTypes []* sql.ColumnType ) []interface {} {
349- scanArgs := make ([]interface {}, len (columnTypes ))
350- for i , v := range columnTypes {
351- switch v .DatabaseTypeName () {
352- case "BOOL" :
353- scanArgs [i ] = new (sql.NullBool )
354- case "INT" , "MEDIUMINT" , "SMALLINT" , "CHAR" , "TINYINT" :
355- scanArgs [i ] = new (sql.NullInt32 )
356- case "BIGINT" :
357- scanArgs [i ] = new (sql.NullInt64 )
358- case "DOUBLE" , "FLOAT" , "DECIMAL" :
359- scanArgs [i ] = new (sql.NullFloat64 )
360- case "DATE" , "TIME" , "YEAR" :
361- scanArgs [i ] = new (sql.NullTime )
362- default :
363- scanArgs [i ] = new (sql.NullString )
364- }
365- }
366-
367- return scanArgs
368- }
0 commit comments