-
Notifications
You must be signed in to change notification settings - Fork 4
/
source.go
264 lines (227 loc) · 6.6 KB
/
source.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
package scan
import (
"context"
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
"sync"
)
var (
	// matchFirstCapRe matches any character followed by a capitalised word
	// (one uppercase letter and one or more lowercase letters).
	matchFirstCapRe = regexp.MustCompile(`(.)([A-Z][a-z]+)`)

	// matchAllCapRe matches a lowercase letter or digit immediately followed
	// by an uppercase letter.
	matchAllCapRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)

	// defaultStructMapper is a mapper source built with the package defaults.
	defaultStructMapper = newDefaultMapperSourceImpl()
)
// snakeCaseFieldFunc is the default field-name mapper. It converts a Go
// identifier to snake case, e.g. "UserID" -> "user_id",
// "CreatedAt" -> "created_at", "HTTPServer" -> "http_server".
func snakeCaseFieldFunc(str string) string {
	const boundary = "${1}_${2}"
	// Break before capitalised words first ("HTTPServer" -> "HTTP_Server"),
	// then between a lowercase letter/digit and a capital ("UserID" -> "User_ID").
	withWords := matchFirstCapRe.ReplaceAllString(str, boundary)
	return strings.ToLower(matchAllCapRe.ReplaceAllString(withWords, boundary))
}
// newDefaultMapperSourceImpl returns a mapper source with the package
// defaults: the "db" struct tag, "." as the nested-column separator,
// snake_case field names, sql.Scanner as the only scannable interface,
// a recursion limit of 3 per type, and an empty mapping cache.
func newDefaultMapperSourceImpl() *mapperSourceImpl {
	// sql.Scanner is an interface, so obtain its reflect.Type through a typed
	// nil pointer and Elem().
	sqlScannerType := reflect.TypeOf((*sql.Scanner)(nil)).Elem()

	return &mapperSourceImpl{
		cache:           make(map[reflect.Type]mapping),
		maxDepth:        3,
		scannableTypes:  []reflect.Type{sqlScannerType},
		fieldMapperFn:   snakeCaseFieldFunc,
		columnSeparator: ".",
		structTagKey:    "db",
	}
}
// NewStructMapperSource returns a StructMapperSource that starts from the
// default settings and then applies opts in the order given.
// It stops at, and returns, the first error reported by an option.
func NewStructMapperSource(opts ...MappingSourceOption) (StructMapperSource, error) {
	source := newDefaultMapperSourceImpl()
	for idx := range opts {
		if err := opts[idx](source); err != nil {
			return nil, err
		}
	}
	return source, nil
}
// MappingSourceOption is a functional option that modifies how a struct's
// mappings are interpreted. NewStructMapperSource applies options in order
// and stops at the first one that returns a non-nil error.
type MappingSourceOption func(src *mapperSourceImpl) error
// WithStructTagKey makes the mapper read explicit column names from the given
// struct tag key instead of the default `db`. Only the part of the tag value
// before the first comma is used as the column name, and a value of "-"
// excludes the field entirely.
func WithStructTagKey(tagKey string) MappingSourceOption {
	apply := func(src *mapperSourceImpl) error {
		src.structTagKey = tagKey
		return nil
	}
	return apply
}
// WithColumnSeparator sets the string placed between a parent field's column
// name and a nested struct field's column name (e.g. "user" + "." + "id").
// Embedded (anonymous) struct fields do not add a segment.
// The default separator is ".".
func WithColumnSeparator(separator string) MappingSourceOption {
	return MappingSourceOption(func(src *mapperSourceImpl) error {
		src.columnSeparator = separator
		return nil
	})
}
// WithFieldNameMapper sets the function that derives a column name from a
// struct field name when the field's struct tag does not supply one.
// The default function maps field names to "snake_case".
//
// The returned option reports an error if mapperFn is nil. The mapper is
// invoked for every untagged, non-embedded exported field, so accepting nil
// here would only defer the failure to a nil-function panic at mapping time.
func WithFieldNameMapper(mapperFn func(string) string) MappingSourceOption {
	return func(src *mapperSourceImpl) error {
		if mapperFn == nil {
			return fmt.Errorf("field name mapper must not be nil")
		}
		src.fieldMapperFn = mapperFn
		return nil
	}
}
// WithScannableTypes registers interfaces that the underlying database
// library can scan into directly.
//
// When the destination type passed to scan implements one of these
// interfaces, scan handles it like a primitive: the destination is passed
// straight to the database library instead of having database columns mapped
// onto struct fields or map keys. Reflection cannot name an interface type
// from a value, so each argument must be a pointer to the interface, usually a
// typed nil.
//
// For example, if your database library defines a scanner interface like this:
//
//	type Scanner interface {
//		Scan(...) error
//	}
//
// register it with:
//
//	scan.WithScannableTypes((*Scanner)(nil))
//
// The returned option fails if an argument is nil, not a pointer, or a
// pointer to something other than an interface.
func WithScannableTypes(scannableTypes ...any) MappingSourceOption {
	return func(src *mapperSourceImpl) error {
		for _, candidate := range scannableTypes {
			ptrType := reflect.TypeOf(candidate)
			switch {
			case ptrType == nil:
				return fmt.Errorf("scannable type must be a pointer, got %T", candidate)
			case ptrType.Kind() != reflect.Pointer:
				return fmt.Errorf("scannable type must be a pointer, got %s: %s",
					ptrType.Kind(), ptrType.String())
			}

			ifaceType := ptrType.Elem()
			if ifaceType.Kind() != reflect.Interface {
				return fmt.Errorf("scannable type must be a pointer to an interface, got %s: %s",
					ifaceType.Kind(), ifaceType.String())
			}
			src.scannableTypes = append(src.scannableTypes, ifaceType)
		}
		return nil
	}
}
// mapperSourceImpl is an implementation of StructMapperSource.
// It holds the settings that control how struct fields are turned into
// column names, plus a per-type cache of the computed mappings.
type mapperSourceImpl struct {
	// structTagKey is the struct tag read for explicit column names
	// (default "db"); a tag name of "-" excludes the field.
	structTagKey string

	// columnSeparator joins a nested struct field's prefix and its own
	// column name (default ".").
	columnSeparator string

	// fieldMapperFn derives a column name from a field name when the tag
	// gives none (default snakeCaseFieldFunc).
	fieldMapperFn func(string) string

	// scannableTypes are interface types; a type whose pointer implements
	// one of them is mapped as a single column rather than walked field by
	// field.
	scannableTypes []reflect.Type

	// maxDepth caps how many times setMappings descends into the same type,
	// using its visited counts; this bounds recursion through
	// self-referential types.
	maxDepth int

	// cache memoizes getMapping results by type. Guarded by mutex.
	cache map[reflect.Type]mapping

	mutex sync.RWMutex
}
// getMapping returns the column mapping for typ, building it with setMappings
// on first use and storing it in s.cache.
//
// The cache lookup uses the read lock. The mapping itself is computed without
// holding any lock, so concurrent first calls for the same type may each
// compute it, and the last write to the cache wins. The error is currently
// always nil.
func (s *mapperSourceImpl) getMapping(typ reflect.Type) (mapping, error) {
	s.mutex.RLock()
	cached, hit := s.cache[typ]
	s.mutex.RUnlock()
	if hit {
		return cached, nil
	}

	var computed mapping
	s.setMappings(typ, "", make(visited), &computed, nil)

	s.mutex.Lock()
	defer s.mutex.Unlock()
	s.cache[typ] = computed
	return computed, nil
}
// setMappings walks typ and appends to *m one mapinfo per column that can be
// scanned into it.
//
// The arguments are:
//   - prefix: the column name accumulated from enclosing structs.
//   - position: the field-index path from the root type to typ (presumably
//     consumed via reflect's FieldByIndex; confirm in the caller).
//   - inits: the index paths of pointer fields on the way down that must be
//     allocated before a value beneath them can be written.
//   - v: counts how many times each type has been entered. Recursion stops once
//     a type's count exceeds s.maxDepth.
//
// Entries are produced as follows:
//   - If a pointer to typ implements one of s.scannableTypes, the whole value
//     becomes a single column named prefix.
//   - Otherwise each exported field not tagged "-" becomes a column.
//     Struct-typed fields are walked recursively with an extended prefix;
//     embedded fields keep the parent's prefix.
//   - If typ has no such field (e.g. time.Time), the whole value becomes a
//     single column named prefix.
//
// Index slices are never appended to in place, so a mapinfo recorded in *m
// cannot be overwritten by later sibling fields or deeper recursion.
func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visited, m *mapping, inits [][]int, position ...int) {
	count := v[typ]
	if count > s.maxDepth {
		return
	}
	v[typ] = count + 1

	var isPointer bool
	if typ.Kind() == reflect.Pointer {
		isPointer = true
		typ = typ.Elem()
	}

	// If it implements a scannable type, then it can be used
	// as a value itself.
	for _, scannable := range s.scannableTypes {
		if reflect.PointerTo(typ).Implements(scannable) {
			*m = append(*m, mapinfo{
				name:      prefix,
				position:  position,
				init:      inits,
				isPointer: isPointer,
			})
			return
		}
	}

	var hasExported bool

	// Go through the struct fields and populate the map.
	// Recursively go into any child structs, adding a prefix where necessary.
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)

		// Don't consider unexported fields.
		if !field.IsExported() {
			continue
		}

		// Skip columns that have the tag "-". Only the part before the first
		// comma names the column.
		tag, _, _ := strings.Cut(field.Tag.Get(s.structTagKey), ",")
		if tag == "-" {
			continue
		}

		hasExported = true

		// Embedded (anonymous) fields contribute their columns under the
		// parent's prefix; named fields extend it.
		key := prefix
		if !field.Anonymous {
			name := tag
			if name == "" {
				name = s.fieldMapperFn(field.Name)
			}
			if prefix == "" {
				key = name
			} else {
				key = prefix + s.columnSeparator + name
			}
		}

		// Build a fresh index path instead of append(position, i): position
		// may have spare capacity, and appending in place would make every
		// field at this level share, and overwrite, one backing array.
		currentIndex := make([]int, len(position)+1)
		copy(currentIndex, position)
		currentIndex[len(position)] = i

		fieldType := field.Type
		fieldInits := inits
		var fieldIsPointer bool
		if fieldType.Kind() == reflect.Pointer {
			// This pointer must be allocated before anything beneath it is
			// written. Extend a capacity-clipped view of inits so the addition
			// applies to this field only and cannot clobber init lists already
			// recorded in *m.
			fieldInits = append(inits[:len(inits):len(inits)], currentIndex)
			fieldType = fieldType.Elem()
			fieldIsPointer = true
		}

		if fieldType.Kind() == reflect.Struct {
			s.setMappings(field.Type, key, v.copy(), m, fieldInits, currentIndex...)
			continue
		}

		*m = append(*m, mapinfo{
			name:      key,
			position:  currentIndex,
			init:      fieldInits,
			isPointer: fieldIsPointer,
		})
	}

	// If it has no exported field (such as time.Time) then we attempt to
	// directly scan into it.
	if !hasExported {
		*m = append(*m, mapinfo{
			name:      prefix,
			position:  position,
			init:      inits,
			isPointer: isPointer,
		})
	}
}
// filterColumns narrows m to the entries that correspond to the columns in c,
// in the order the columns appear in c.
//
// Columns that do not start with prefix are skipped, and the prefix is
// stripped before comparing against mapping names. For each remaining
// column, the first entry in m with a matching name is kept and renamed to the
// full (unstripped) column name. Columns with no matching entry are dropped.
//
// ctx is not used here, and the returned error is always nil.
func filterColumns(ctx context.Context, c cols, m mapping, prefix string) (mapping, error) {
	kept := make(mapping, 0, len(c))
	for _, column := range c {
		// Every string has "" as a prefix, so this also covers prefix == "".
		if !strings.HasPrefix(column, prefix) {
			continue
		}
		wanted := column[len(prefix):]

		for idx := range m {
			if m[idx].name != wanted {
				continue
			}
			match := m[idx]
			match.name = column
			kept = append(kept, match)
			break
		}
	}
	return kept, nil
}