diff --git a/reflectx/reflect.go b/reflectx/reflect.go index 8ec6a13..90e1a12 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -6,8 +6,11 @@ package reflectx import ( + "database/sql" + "fmt" "reflect" "runtime" + "strconv" "strings" "sync" ) @@ -200,6 +203,192 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in return nil } +// ObjectContext provides a single layer to abstract away +// nested struct scanning functionality +type ObjectContext struct { + value reflect.Value +} + +func NewObjectContext() *ObjectContext { + return &ObjectContext{} +} + +// NewRow updates the object reference. +// This ensures all columns point to the same object +func (o *ObjectContext) NewRow(value reflect.Value) { + o.value = value +} + +// FieldForIndexes returns the value for address. If the address is a nested struct, +// a nestedFieldScanner is returned instead of the standard value reference +func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { + if len(indexes) == 1 { + val := FieldByIndexes(o.value, indexes) + return val + } + + obj := &nestedFieldScanner{ + parent: o, + indexes: indexes, + } + + v := reflect.ValueOf(obj).Elem() + return v +} + +// getFieldByIndex returns a value for the field given by the struct traversal +// for the given value. +func (o *ObjectContext) getFieldByIndex(indexes []int) reflect.Value { + return FieldByIndexes(o.value, indexes) +} + +// nestedFieldScanner will only forward the Scan to the nested value if +// the database value is not nil. +type nestedFieldScanner struct { + parent *ObjectContext + indexes []int +} + +// Scan implements sql.Scanner. +// This method largely mirrors the sql.convertAssign() method with some minor changes +func (o *nestedFieldScanner) Scan(src interface{}) error { + if src == nil { + return nil + } + + dv := FieldByIndexes(o.parent.value, o.indexes) + iface := dv.Addr().Interface() + + if scan, ok := iface.(sql.Scanner); ok { + return scan.Scan(src) + } + + sv := reflect.ValueOf(src) + + // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go + // with a few minor edits + + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(bytesClone(b))) + default: + dv.Set(sv) + } + + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) +} + +// returns internal conversion error if available +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +// converts value to it's string value +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +// bytesClone returns a copy of b[:len(b)]. +// The result may have additional unused capacity. +// Clone(nil) returns nil. +// +// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version +func bytesClone(b []byte) []byte { + if b == nil { + return nil + } + return append([]byte{}, b...) +} + // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { diff --git a/sqlx.go b/sqlx.go index 8259a4f..e0ef63d 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,7 +624,8 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - err := fieldsByTraversal(v, r.fields, r.values, true) + octx := reflectx.NewObjectContext() + err := fieldsByTraversal(octx, v, r.fields, r.values, true) if err != nil { return err } @@ -784,7 +785,9 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - err = fieldsByTraversal(v, fields, values, true) + octx := reflectx.NewObjectContext() + + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -951,13 +954,14 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) + octx := reflectx.NewObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -1023,18 +1027,20 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } + octx.NewRow(v) + for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) continue } - f := reflectx.FieldByIndexes(v, traversal) + f := octx.FieldForIndexes(traversal) if ptrs { values[i] = f.Addr().Interface() } else { diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 91c5cba..f355d82 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -643,6 +643,110 @@ func TestNamedQueryContext(t *testing.T) { t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) } } + + rows.Close() + + type Owner struct { + Email string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + + // Test optional nested structs with left join + type PlaceOwner struct { + Place Place `db:"place"` + Owner *Owner `db:"owner"` + } + + pl = Place{ + Name: sql.NullString{String: "the-house", Valid: true}, + } + + q4 := `INSERT INTO place (id, name) VALUES (2, :name)` + _, err = db.NamedExecContext(ctx, q4, pl) + if err != nil { + log.Fatal(err) + } + + id = 2 + pp.Place.ID = id + + q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q5, pp) + if err != nil { + log.Fatal(err) + } + + pp3 := &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + LEFT JOIN placeperson ON false -- null left join + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner != nil { + t.Error("Expected `Owner`, to be nil") + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } + + rows.Close() + + pp3 = &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + left JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner == nil { + t.Error("Expected `Owner`, to not be nil") + } + + if pp3.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName) + } + if pp3.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName) + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } }) }