Skip to content

Commit

Permalink
Update reflectx to allow for optional nested structs
Browse files Browse the repository at this point in the history
Nested structs are now only instantiated when one of the database columns in that nested struct is not nil. This allows objects scanned in left/outer joins to keep their natural types (instead of setting everything to NullableX).

Example:

select house.id, owner.*,
from house
left join owner on owner.id = house.owner

type House struct {
 ID      int
 Owner   *Person // if left join gives nulls, Owner will be nil
}

type Owner struct {
 ID int  // no need to set this to sql.NullInt
}
  • Loading branch information
ntbosscher committed Jan 23, 2023
1 parent 28212d4 commit 81db673
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 16 deletions.
190 changes: 189 additions & 1 deletion reflectx/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
// allows for Go-compatible named attribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
//
package reflectx

import (
"database/sql"
"fmt"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
)
Expand Down Expand Up @@ -201,6 +203,192 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in
return nil
}

// ObjectContext provides a single layer to abstract away
// nested struct scanning functionality
type ObjectContext struct {
value reflect.Value
}

func NewObjectContext() *ObjectContext {
return &ObjectContext{}
}

// NewRow updates the object reference.
// This ensures all columns point to the same object
func (o *ObjectContext) NewRow(value reflect.Value) {
o.value = value
}

// FieldForIndexes returns the value for address. If the address is a nested struct,
// a nestedFieldScanner is returned instead of the standard value reference
func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value {
if len(indexes) == 1 {
val := FieldByIndexes(o.value, indexes)
return val
}

obj := &nestedFieldScanner{
parent: o,
indexes: indexes,
}

v := reflect.ValueOf(obj).Elem()
return v
}

// getFieldByIndex returns a value for the field given by the struct traversal
// for the given value.
func (o *ObjectContext) getFieldByIndex(indexes []int) reflect.Value {
return FieldByIndexes(o.value, indexes)
}

// nestedFieldScanner will only forward the Scan to the nested value if
// the database value is not nil.
type nestedFieldScanner struct {
parent *ObjectContext
indexes []int
}

// Scan implements sql.Scanner.
// This method largely mirrors the sql.convertAssign() method with some minor changes
func (o *nestedFieldScanner) Scan(src interface{}) error {
if src == nil {
return nil
}

dv := FieldByIndexes(o.parent.value, o.indexes)
iface := dv.Addr().Interface()

if scan, ok := iface.(sql.Scanner); ok {
return scan.Scan(src)
}

sv := reflect.ValueOf(src)

// below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go
// with a few minor edits

if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(bytesClone(b)))
default:
dv.Set(sv)
}

return nil
}

if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
dv.Set(sv.Convert(dv.Type()))
return nil
}

// The following conversions use a string value as an intermediate representation
// to convert between various numeric types.
//
// This also allows scanning into user defined types such as "type Int int64".
// For symmetry, also check for string destination types.
switch dv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
case reflect.String:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
switch v := src.(type) {
case string:
dv.SetString(v)
return nil
case []byte:
dv.SetString(string(v))
return nil
}
}

return fmt.Errorf("don't know how to parse type %T -> %T", src, iface)
}

// returns internal conversion error if available
// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}

// converts value to it's string value
// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go
func asString(src interface{}) string {
switch v := src.(type) {
case string:
return v
case []byte:
return string(v)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(rv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(rv.Uint(), 10)
case reflect.Float64:
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
case reflect.Float32:
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
case reflect.Bool:
return strconv.FormatBool(rv.Bool())
}
return fmt.Sprintf("%v", src)
}

// bytesClone returns a copy of b[:len(b)].
// The result may have additional unused capacity.
// Clone(nil) returns nil.
//
// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version
func bytesClone(b []byte) []byte {
if b == nil {
return nil
}
return append([]byte{}, b...)
}

// FieldByIndexes returns a value for the field given by the struct traversal
// for the given value.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
Expand Down
29 changes: 18 additions & 11 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func mapper() *reflectx.Mapper {

// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. Something is scannable if:
// * it is not a struct
// * it implements sql.Scanner
// * it has no exported fields
// - it is not a struct
// - it implements sql.Scanner
// - it has no exported fields
func isScannable(t reflect.Type) bool {
if reflect.PtrTo(t).Implements(_scannerInterface) {
return true
Expand Down Expand Up @@ -621,7 +621,8 @@ func (r *Rows) StructScan(dest interface{}) error {
r.started = true
}

err := fieldsByTraversal(v, r.fields, r.values, true)
octx := reflectx.NewObjectContext()
err := fieldsByTraversal(octx, v, r.fields, r.values, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -781,7 +782,9 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
}
values := make([]interface{}, len(columns))

err = fieldsByTraversal(v, fields, values, true)
octx := reflectx.NewObjectContext()

err = fieldsByTraversal(octx, v, fields, values, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -884,9 +887,9 @@ func structOnlyError(t reflect.Type) error {
// then each row must only have one column which can scan into that type. This
// allows you to do something like:
//
// rows, _ := db.Query("select id from people;")
// var ids []int
// scanAll(rows, &ids, false)
// rows, _ := db.Query("select id from people;")
// var ids []int
// scanAll(rows, &ids, false)
//
// and ids will be a list of the id results. I realize that this is a desirable
// interface to expose to users, but for now it will only be exposed via changes
Expand Down Expand Up @@ -948,13 +951,14 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
}
values = make([]interface{}, len(columns))
octx := reflectx.NewObjectContext()

for rows.Next() {
// create a new struct type (which returns PtrTo) and indirect it
vp = reflect.New(base)
v = reflect.Indirect(vp)

err = fieldsByTraversal(v, fields, values, true)
err = fieldsByTraversal(octx, v, fields, values, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -1020,18 +1024,21 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}

octx.NewRow(v)

for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)

f := octx.FieldForIndexes(traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
Expand Down
Loading

0 comments on commit 81db673

Please sign in to comment.