Skip to content

Commit

Permalink
load: add package
Browse files Browse the repository at this point in the history
  • Loading branch information
yansal committed Aug 23, 2022
1 parent 17b3535 commit 7386a03
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 1 deletion.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/yansal/sql

go 1.13
go 1.19
60 changes: 60 additions & 0 deletions load/find.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package load

import (
"context"

"github.com/yansal/sql/build"
)

func Find[
T any,
PtrToT interface {
*T
Model
},
](ctx context.Context, db Querier, where build.Expression) ([]T, error) {
var dest []T
if err := find[T, PtrToT](ctx, db, &dest, where); err != nil {
return nil, err
}
return dest, nil
}

func find[
T any,
PtrToT interface {
*T
Model
},
](ctx context.Context, db Querier, dest *[]T, where build.Expression) error {
var (
model PtrToT
columns = model.GetColumns()
table = model.GetTable()
)
query, args := build.Select(build.Columns(columns...)...).
From(build.Ident(table)).
Where(where).
Build()

rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer rows.Close()

for rows.Next() {
var (
v T
vptr PtrToT = &v
)
if err := rows.Scan(vptr.GetDests()...); err != nil {
return err
}
*dest = append(*dest, v)
}
if err := rows.Err(); err != nil {
return err
}
return nil
}
66 changes: 66 additions & 0 deletions load/get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package load

import (
"context"
"database/sql"

"github.com/yansal/sql/build"
)

type Model interface {
GetColumns() []string
GetTable() string
GetDests() []any
}

type Querier interface {
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
}

func Get[
T any,
PtrToT interface {
*T
Model
},
](ctx context.Context, db Querier, where build.Expression) (PtrToT, error) {
var (
dest T
destptr PtrToT = &dest
)
if err := get(ctx, db, destptr, where); err != nil {
return nil, err
}
return destptr, nil
}

func get(ctx context.Context, db Querier, dest Model, where build.Expression) error {
var (
columns = dest.GetColumns()
table = dest.GetTable()
)
query, args := build.Select(build.Columns(columns...)...).
From(build.Ident(table)).
Where(where).
Build()

rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer rows.Close()

if !rows.Next() {
return sql.ErrNoRows
}
if err := rows.Scan(dest.GetDests()...); err != nil {
return err
}
if err := rows.Close(); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
}
return nil
}
59 changes: 59 additions & 0 deletions load/nested.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package load

import (
"context"
"fmt"
"strings"
)

type NestedModel interface {
GetField(string) any
SetField(string, any)
}

func PreloadSliceNested[
Nested any,
PreloadDest any,
PreloadSrc any,
PtrToNested interface {
*Nested
PreloadModel
},
PtrToPreloadDest interface {
*PreloadDest
Model
},
PtrToPreloadSrc interface {
*PreloadSrc
NestedModel
},
](ctx context.Context, db Querier, srcs []PreloadSrc, destname string) error {
split := strings.Split(destname, ".")
if len(split) != 2 {
panic(fmt.Sprintf("expected 1 nested preload destname, got %q", destname))
}

var (
allnested []Nested
indexes []struct{ i, j int }
)
for i := range srcs {
var ptr PtrToPreloadSrc = &srcs[i]
nested := ptr.GetField(split[0]).([]Nested)
indexes = append(indexes, struct{ i, j int }{i: len(allnested), j: len(allnested) + len(nested)})
allnested = append(allnested, nested...)
}
if err := PreloadSlice[
PreloadDest,
Nested,
PtrToPreloadDest,
PtrToNested,
](ctx, db, allnested, split[1]); err != nil {
return err
}
for i := range srcs {
var ptr PtrToPreloadSrc = &srcs[i]
ptr.SetField(split[0], allnested[indexes[i].i:indexes[i].j])
}
return nil
}
117 changes: 117 additions & 0 deletions load/preload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package load

import (
"context"
"database/sql/driver"

"github.com/yansal/sql/build"
)

type PreloadModel interface {
GetPreloadBindValue(string) any
GetPreloadDestIdent(string) string
GetPreloadDestValue(string, any) any
SetPreloadDest(string, any)
}

func PreloadSlice[
PreloadDest any,
PreloadSrc any,
PtrToPreloadDest interface {
*PreloadDest
Model
},
PtrToPreloadSrc interface {
*PreloadSrc
PreloadModel
},
](ctx context.Context, db Querier, srcs []PreloadSrc, destname string) error {
if len(srcs) == 0 {
return nil
}
bindvaluemap := make(map[any]struct{})
for i := range srcs {
var (
srcptr PtrToPreloadSrc = &srcs[i]
bindvalue = srcptr.GetPreloadBindValue(destname)
)
if valuer, ok := bindvalue.(driver.Valuer); ok {
value, err := valuer.Value()
if err != nil {
return err
}
if value == nil {
continue
}
bindvalue = value
}
bindvaluemap[bindvalue] = struct{}{}
}
if len(bindvaluemap) == 0 {
return nil
}
bindvalues := make([]any, 0, len(bindvaluemap))
for v := range bindvaluemap {
bindvalues = append(bindvalues, v)
}

var (
srcmodel PtrToPreloadSrc
where = build.Ident(srcmodel.GetPreloadDestIdent(destname)).In(build.Bind(bindvalues))
)
dests, err := Find[PreloadDest, PtrToPreloadDest](ctx, db, where)
if err != nil {
return err
}

destmap := make(map[any][]PreloadDest)
for i := range dests {
destvalue := srcmodel.GetPreloadDestValue(destname, dests[i])
destmap[destvalue] = append(destmap[destvalue], dests[i])
}
for i := range srcs {
var (
srcptr PtrToPreloadSrc = &srcs[i]
bindvalue = srcptr.GetPreloadBindValue(destname)
)
if valuer, ok := bindvalue.(driver.Valuer); ok {
value, err := valuer.Value()
if err != nil {
return err
}
if value == nil {
continue
}
bindvalue = value
}
if v := destmap[bindvalue]; len(v) > 0 {
srcptr.SetPreloadDest(destname, v)
}
}
return nil
}

func PreloadPtr[
PreloadDest any,
PtrToPreloadDest interface {
*PreloadDest
Model
},
PreloadSrc any,
PtrToPreloadSrc interface {
*PreloadSrc
PreloadModel
},
](ctx context.Context, db Querier, srcptr PtrToPreloadSrc, destname string) error {
srcs := []PreloadSrc{*srcptr}
if err := PreloadSlice[
PreloadDest,
PreloadSrc,
PtrToPreloadDest,
PtrToPreloadSrc,
](ctx, db, srcs, destname); err != nil {
return err
}
*srcptr = srcs[0]
return nil
}

0 comments on commit 7386a03

Please sign in to comment.