-
-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from mtsoltan/fix-scan-value
Fix option to handle type conversion
- Loading branch information
Showing
5 changed files
with
382 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,4 @@ go.work | |
cover.out | ||
cover.html | ||
.vscode | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,327 @@ | ||
//go:build !go1.22 | ||
// +build !go1.22 | ||
|
||
package mo | ||
|
||
// | ||
// sql.Null[T] has been introduce in go1.22 | ||
// This file is a copy of stdlib and ensure retro-compatibility. | ||
// See https://github.com/samber/mo/pull/49 | ||
// | ||
|
||
import ( | ||
"database/sql" | ||
"database/sql/driver" | ||
"errors" | ||
"fmt" | ||
"reflect" | ||
"strconv" | ||
"time" | ||
) | ||
|
||
var errNilPtr = errors.New("destination pointer is nil") | ||
|
||
func cloneBytes(b []byte) []byte { | ||
if b == nil { | ||
return nil | ||
} | ||
c := make([]byte, len(b)) | ||
copy(c, b) | ||
return c | ||
} | ||
|
||
func asString(src any) string { | ||
switch v := src.(type) { | ||
case string: | ||
return v | ||
case []byte: | ||
return string(v) | ||
} | ||
rv := reflect.ValueOf(src) | ||
switch rv.Kind() { | ||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||
return strconv.FormatInt(rv.Int(), 10) | ||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||
return strconv.FormatUint(rv.Uint(), 10) | ||
case reflect.Float64: | ||
return strconv.FormatFloat(rv.Float(), 'g', -1, 64) | ||
case reflect.Float32: | ||
return strconv.FormatFloat(rv.Float(), 'g', -1, 32) | ||
case reflect.Bool: | ||
return strconv.FormatBool(rv.Bool()) | ||
} | ||
return fmt.Sprintf("%v", src) | ||
} | ||
|
||
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { | ||
switch rv.Kind() { | ||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||
return strconv.AppendInt(buf, rv.Int(), 10), true | ||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||
return strconv.AppendUint(buf, rv.Uint(), 10), true | ||
case reflect.Float32: | ||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true | ||
case reflect.Float64: | ||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true | ||
case reflect.Bool: | ||
return strconv.AppendBool(buf, rv.Bool()), true | ||
case reflect.String: | ||
s := rv.String() | ||
return append(buf, s...), true | ||
} | ||
return | ||
} | ||
|
||
func strconvErr(err error) error { | ||
if ne, ok := err.(*strconv.NumError); ok { | ||
return ne.Err | ||
} | ||
return err | ||
} | ||
|
||
// convertAssignRows copies to dest the value in src, converting it if possible. | ||
// An error is returned if the copy would result in loss of information. | ||
// dest should be a pointer type. If rows is passed in, the rows will | ||
// be used as the parent for any cursor values converted from a | ||
// driver.Rows to a *Rows. | ||
func convertAssign(dest, src any) error { | ||
// Common cases, without reflect. | ||
switch s := src.(type) { | ||
case string: | ||
switch d := dest.(type) { | ||
case *string: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = s | ||
return nil | ||
case *[]byte: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = []byte(s) | ||
return nil | ||
case *sql.RawBytes: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = append((*d)[:0], s...) | ||
return nil | ||
} | ||
case []byte: | ||
switch d := dest.(type) { | ||
case *string: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = string(s) | ||
return nil | ||
case *any: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = cloneBytes(s) | ||
return nil | ||
case *[]byte: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = cloneBytes(s) | ||
return nil | ||
case *sql.RawBytes: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = s | ||
return nil | ||
} | ||
case time.Time: | ||
switch d := dest.(type) { | ||
case *time.Time: | ||
*d = s | ||
return nil | ||
case *string: | ||
*d = s.Format(time.RFC3339Nano) | ||
return nil | ||
case *[]byte: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = []byte(s.Format(time.RFC3339Nano)) | ||
return nil | ||
case *sql.RawBytes: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano) | ||
return nil | ||
} | ||
case nil: | ||
switch d := dest.(type) { | ||
case *any: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = nil | ||
return nil | ||
case *[]byte: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = nil | ||
return nil | ||
case *sql.RawBytes: | ||
if d == nil { | ||
return errNilPtr | ||
} | ||
*d = nil | ||
return nil | ||
} | ||
} | ||
|
||
var sv reflect.Value | ||
|
||
switch d := dest.(type) { | ||
case *string: | ||
sv = reflect.ValueOf(src) | ||
switch sv.Kind() { | ||
case reflect.Bool, | ||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, | ||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, | ||
reflect.Float32, reflect.Float64: | ||
*d = asString(src) | ||
return nil | ||
} | ||
case *[]byte: | ||
sv = reflect.ValueOf(src) | ||
if b, ok := asBytes(nil, sv); ok { | ||
*d = b | ||
return nil | ||
} | ||
case *sql.RawBytes: | ||
sv = reflect.ValueOf(src) | ||
if b, ok := asBytes([]byte(*d)[:0], sv); ok { | ||
*d = sql.RawBytes(b) | ||
return nil | ||
} | ||
case *bool: | ||
bv, err := driver.Bool.ConvertValue(src) | ||
if err == nil { | ||
*d = bv.(bool) | ||
} | ||
return err | ||
case *any: | ||
*d = src | ||
return nil | ||
} | ||
|
||
if scanner, ok := dest.(sql.Scanner); ok { | ||
return scanner.Scan(src) | ||
} | ||
|
||
dpv := reflect.ValueOf(dest) | ||
if dpv.Kind() != reflect.Pointer { | ||
return errors.New("destination not a pointer") | ||
} | ||
if dpv.IsNil() { | ||
return errNilPtr | ||
} | ||
|
||
if !sv.IsValid() { | ||
sv = reflect.ValueOf(src) | ||
} | ||
|
||
dv := reflect.Indirect(dpv) | ||
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { | ||
switch b := src.(type) { | ||
case []byte: | ||
dv.Set(reflect.ValueOf(cloneBytes(b))) | ||
default: | ||
dv.Set(sv) | ||
} | ||
return nil | ||
} | ||
|
||
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { | ||
dv.Set(sv.Convert(dv.Type())) | ||
return nil | ||
} | ||
|
||
// The following conversions use a string value as an intermediate representation | ||
// to convert between various numeric types. | ||
// | ||
// This also allows scanning into user defined types such as "type Int int64". | ||
// For symmetry, also check for string destination types. | ||
switch dv.Kind() { | ||
case reflect.Pointer: | ||
if src == nil { | ||
dv.Set(reflect.Zero(dv.Type())) | ||
return nil | ||
} | ||
dv.Set(reflect.New(dv.Type().Elem())) | ||
return convertAssign(dv.Interface(), src) | ||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||
if src == nil { | ||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | ||
} | ||
s := asString(src) | ||
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) | ||
if err != nil { | ||
err = strconvErr(err) | ||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) | ||
} | ||
dv.SetInt(i64) | ||
return nil | ||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||
if src == nil { | ||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | ||
} | ||
s := asString(src) | ||
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) | ||
if err != nil { | ||
err = strconvErr(err) | ||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) | ||
} | ||
dv.SetUint(u64) | ||
return nil | ||
case reflect.Float32, reflect.Float64: | ||
if src == nil { | ||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | ||
} | ||
s := asString(src) | ||
f64, err := strconv.ParseFloat(s, dv.Type().Bits()) | ||
if err != nil { | ||
err = strconvErr(err) | ||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) | ||
} | ||
dv.SetFloat(f64) | ||
return nil | ||
case reflect.String: | ||
if src == nil { | ||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | ||
} | ||
switch v := src.(type) { | ||
case string: | ||
dv.SetString(v) | ||
return nil | ||
case []byte: | ||
dv.SetString(string(v)) | ||
return nil | ||
} | ||
} | ||
|
||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) | ||
} | ||
|
||
func (o *Option[T]) scanConvertValue(src any) error { | ||
// we try to convertAssign values that we can't directly assign because ConvertValue | ||
// will return immediately for v that is already a Value, even if it is a different | ||
// Value type than the one we expect here. | ||
var dest T | ||
if err := convertAssign(&dest, src); err == nil { | ||
o.isPresent = true | ||
o.value = dest | ||
return nil | ||
} | ||
return fmt.Errorf("failed to scan Option[T]") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
//go:build go1.22 | ||
// +build go1.22 | ||
|
||
package mo | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
) | ||
|
||
func (o *Option[T]) scanConvertValue(src any) error { | ||
// we try to convertAssign values that we can't directly assign because ConvertValue | ||
// will return immediately for v that is already a Value, even if it is a different | ||
// Value type than the one we expect here. | ||
var st sql.Null[T] | ||
if err := st.Scan(src); err == nil { | ||
o.isPresent = true | ||
o.value = st.V | ||
return nil | ||
} | ||
return fmt.Errorf("failed to scan Option[T]") | ||
} |
Oops, something went wrong.