From 08d14d201de3053667971f9652f3c000582930c3 Mon Sep 17 00:00:00 2001 From: mtsoltan Date: Sun, 14 Jul 2024 16:48:14 +0900 Subject: [PATCH 1/4] Fix option to handle type conversion. --- .gitignore | 1 + option.go | 12 +++++++++++- option_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3aa3a0a..7b9b3e7 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ go.work cover.out cover.html .vscode +.idea \ No newline at end of file diff --git a/option.go b/option.go index 9955094..5212556 100644 --- a/option.go +++ b/option.go @@ -297,6 +297,16 @@ func (o *Option[T]) Scan(src any) error { } } + // we try to convertAssign values that we can't directly assign because ConvertValue + // will return immediately for v that is already a Value, even if it is a different + // Value type than the one we expect here. + var st sql.Null[T] + if err := st.Scan(src); err == nil { + o.isPresent = true + o.value = st.V + return nil + } + return fmt.Errorf("failed to scan Option[T]") } @@ -306,7 +316,7 @@ func (o Option[T]) Value() (driver.Value, error) { return nil, nil } - return o.value, nil + return driver.DefaultParameterConverter.ConvertValue(o.value) } // leftValue returns an error if the Option is None, otherwise nil diff --git a/option_test.go b/option_test.go index 55835ea..0969f8b 100644 --- a/option_test.go +++ b/option_test.go @@ -407,6 +407,25 @@ func TestOptionScan(t *testing.T) { is.Equal(err2Exp, err2) } +func TestOptionScanWithPossibleConvert(t *testing.T) { + is := assert.New(t) + + // As passed by the sql package in some cases, src is a []byte. + // https://github.com/golang/go/blob/071b8d51c1a70fa6b12f0bed2e93370e193333fd/src/database/sql/convert.go#L396 + src1 := []byte{65, 66, 67} + dest1 := None[string]() + src2 := int32(32) + dest2 := None[int]() + + err1 := dest1.Scan(src1) + err2 := dest2.Scan(src2) + + is.Nil(err1) + is.Equal(Some("ABC"), dest1) + is.Nil(err2) + is.Equal(Some(32), dest2) +} + func TestOptionValue(t *testing.T) { is := assert.New(t) @@ -425,6 +444,17 @@ func TestOptionValue(t *testing.T) { is.Nil(err2) } +func TestOptionValueWithPossibleConvert(t *testing.T) { + is := assert.New(t) + + opt := Some(uint32(42)) + expected := int64(42) + + value, err := opt.Value() + is.Nil(err) + is.Equal(expected, value) +} + type SomeScanner struct { Cool bool Some int From ef7a827b54634067f4503675656e33284f49a814 Mon Sep 17 00:00:00 2001 From: mtsoltan Date: Tue, 16 Jul 2024 13:05:55 +0900 Subject: [PATCH 2/4] In 1.18, convertAssign has no way to use it... --- option.go | 12 +- option_go118.go | 321 ++++++++++++++++++++++++++++++++++++++++++++++++ option_go122.go | 22 ++++ 3 files changed, 344 insertions(+), 11 deletions(-) create mode 100644 option_go118.go create mode 100644 option_go122.go diff --git a/option.go b/option.go index 5212556..dc91f7a 100644 --- a/option.go +++ b/option.go @@ -297,17 +297,7 @@ func (o *Option[T]) Scan(src any) error { } } - // we try to convertAssign values that we can't directly assign because ConvertValue - // will return immediately for v that is already a Value, even if it is a different - // Value type than the one we expect here. - var st sql.Null[T] - if err := st.Scan(src); err == nil { - o.isPresent = true - o.value = st.V - return nil - } - - return fmt.Errorf("failed to scan Option[T]") + return o.scanConvertValue(src) } // Value implements the driver Valuer interface. diff --git a/option_go118.go b/option_go118.go new file mode 100644 index 0000000..adbbecc --- /dev/null +++ b/option_go118.go @@ -0,0 +1,321 @@ +//go:build !go1.22 +// +build !go1.22 + +package mo + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +var errNilPtr = errors.New("destination pointer is nil") + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} + +func asString(src any) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +// convertAssignRows copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. If rows is passed in, the rows will +// be used as the parent for any cursor values converted from a +// driver.Rows to a *Rows. +func convertAssign(dest, src any) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *any: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case nil: + switch d := dest.(type) { + case *any: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *sql.RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = sql.RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *any: + *d = src + return nil + } + + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Pointer { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Pointer: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func (o *Option[T]) scanConvertValue(src any) error { + // we try to convertAssign values that we can't directly assign because ConvertValue + // will return immediately for v that is already a Value, even if it is a different + // Value type than the one we expect here. + var dest T + if err := convertAssign(&dest, src); err == nil { + o.isPresent = true + o.value = dest + return nil + } + return fmt.Errorf("failed to scan Option[T]") +} diff --git a/option_go122.go b/option_go122.go new file mode 100644 index 0000000..4f51d37 --- /dev/null +++ b/option_go122.go @@ -0,0 +1,22 @@ +//go:build go1.22 +// +build go1.22 + +package mo + +import ( + "database/sql" + "fmt" +) + +func (o *Option[T]) scanConvertValue(src any) error { + // we try to convertAssign values that we can't directly assign because ConvertValue + // will return immediately for v that is already a Value, even if it is a different + // Value type than the one we expect here. + var st sql.Null[T] + if err := st.Scan(src); err == nil { + o.isPresent = true + o.value = st.V + return nil + } + return fmt.Errorf("failed to scan Option[T]") +} From 70ca3506564e73c237818e7498831dc1781244ce Mon Sep 17 00:00:00 2001 From: Samuel Berthe Date: Tue, 16 Jul 2024 11:37:03 +0200 Subject: [PATCH 3/4] Update option_go118.go --- option_go118.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/option_go118.go b/option_go118.go index adbbecc..4e10889 100644 --- a/option_go118.go +++ b/option_go118.go @@ -3,6 +3,11 @@ package mo +// +// sql.Null[T] has been introduce in go1.22 +// This file is a copy of stdlib and ensure retro-compatibility. +// + import ( "database/sql" "database/sql/driver" From bc7453ef32b5de7c6f4e4b0452cef3f969bb84e6 Mon Sep 17 00:00:00 2001 From: Samuel Berthe Date: Tue, 16 Jul 2024 11:38:08 +0200 Subject: [PATCH 4/4] Update option_go118.go --- option_go118.go | 1 + 1 file changed, 1 insertion(+) diff --git a/option_go118.go b/option_go118.go index 4e10889..d1f9d4d 100644 --- a/option_go118.go +++ b/option_go118.go @@ -6,6 +6,7 @@ package mo // // sql.Null[T] has been introduce in go1.22 // This file is a copy of stdlib and ensure retro-compatibility. +// See https://github.com/samber/mo/pull/49 // import (