-
Notifications
You must be signed in to change notification settings - Fork 20
/
mold.go
356 lines (305 loc) · 9.56 KB
/
mold.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
package mold
import (
"context"
"fmt"
"reflect"
"strings"
"time"
)
// timeType is the reflect.Type of time.Time. Values of this type are rejected
// by Struct and are not descended into as structs by setByField.
var timeType = reflect.TypeOf(time.Time{})

// Panic message formats used when Register / RegisterAlias receive a
// restricted name. Each expects the offending name as its single %s argument.
var (
	restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
	restrictedTagErr   = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
)
// TODO - ensure StructLevel and Func get passed an interface and not *Transform directly

// Transform is the subset of the *Transformer API that is made available
// while that Transformer is executing a transformation.
type Transform interface {
	// Struct applies transformations to the struct that value points to.
	Struct(ctx context.Context, value interface{}) error
	// Field applies the transformations listed in tags to the value that
	// value points to.
	Field(ctx context.Context, value interface{}, tags string) error
}
type (
	// Func is the signature of a single transformation registered with
	// Register. The value being transformed is reachable through fl.
	Func func(ctx context.Context, fl FieldLevel) error

	// StructLevelFunc transforms a whole struct at once.
	//
	// Why does this exist? For structs you may not have access or rights to
	// add tags to, such as types from other packages you're using.
	StructLevelFunc func(ctx context.Context, sl StructLevel) error

	// InterceptorFunc intercepts custom types so that the transformations are
	// redirected to an inner type/value, e.g. for sql.NullString the
	// manipulation should be done on the inner string.
	InterceptorFunc func(current reflect.Value) (inner reflect.Value)
)
// Transformer is the base controlling object which contains
// all necessary information: the registered transformations, aliases,
// struct-level functions and interceptors, plus caches of parsed tags and
// struct layouts. Create one with New.
type Transformer struct {
	// tagName is the struct tag name transformations are read from
	// ("mold" by default, see New; changed via SetTagName).
	tagName string
	// aliases maps an alias name to the tag string it expands to (RegisterAlias).
	aliases map[string]string
	// transformations maps a tag name to its transform function (Register).
	transformations map[string]Func
	// structLevelFuncs maps a struct type to its struct-level function; it is
	// left nil by New and created lazily by RegisterStructLevel.
	structLevelFuncs map[reflect.Type]StructLevelFunc
	// interceptors maps a type to the InterceptorFunc registered for it (RegisterInterceptor).
	interceptors map[reflect.Type]InterceptorFunc
	// cCache caches per-struct-type information (initialised in New with a
	// map[reflect.Type]*cStruct).
	cCache *structCache
	// tCache caches parsed tag chains keyed by the raw tag string
	// (initialised in New with a map[string]*cTag; used by Field).
	tCache *tagCache
}
// New returns a *Transformer that reads transformations from the "mold"
// struct tag. The tag, alias, transformation and interceptor registries start
// empty. The struct-level registry is left nil and is created on the first
// RegisterStructLevel call.
func New() *Transformer {
	tags := new(tagCache)
	tags.m.Store(make(map[string]*cTag))

	structs := new(structCache)
	structs.m.Store(make(map[reflect.Type]*cStruct))

	t := &Transformer{
		tagName:         "mold",
		aliases:         make(map[string]string),
		transformations: make(map[string]Func),
		interceptors:    make(map[reflect.Type]InterceptorFunc),
		cCache:          structs,
		tCache:          tags,
	}
	return t
}
// SetTagName sets the name of the struct tag used to declare transformations.
// The default, set by New, is "mold".
//
// NOTE(review): struct layouts already held in the struct cache were presumably
// extracted using the previous tag name. Call this before transforming any
// structs; confirm whether the cache should be reset here.
func (t *Transformer) SetTagName(tagName string) {
	t.tagName = tagName
}
// Register associates the transformation fn with tag.
//
// NOTES:
//   - registering a tag that already exists replaces its previous function.
//   - this method is not thread-safe; register everything before use.
//
// It panics when tag is empty, when fn is nil, or when tag is restricted
// (a reserved tag, or one containing restricted characters).
func (t *Transformer) Register(tag string, fn Func) {
	switch {
	case tag == "":
		panic("Function Key cannot be empty")
	case fn == nil:
		panic("Function cannot be empty")
	}

	if _, reserved := restrictedTags[tag]; reserved || strings.ContainsAny(tag, restrictedTagChars) {
		panic(fmt.Sprintf(restrictedTagErr, tag))
	}

	t.transformations[tag] = fn
}
// RegisterAlias registers alias as shorthand for tags, a common or complex
// set of transformations, to simplify adding transforms to structs.
//
// NOTE: this method is not thread-safe; register all aliases before use.
//
// It panics when alias or tags is empty, or when alias is restricted
// (a reserved tag, or one containing restricted characters).
func (t *Transformer) RegisterAlias(alias, tags string) {
	switch {
	case alias == "":
		panic("Alias cannot be empty")
	case tags == "":
		panic("Aliased tags cannot be empty")
	}

	_, isRestricted := restrictedTags[alias]
	if isRestricted || strings.ContainsAny(alias, restrictedTagChars) {
		panic(fmt.Sprintf(restrictedAliasErr, alias))
	}

	t.aliases[alias] = tags
}
// RegisterStructLevel registers fn as the struct-level function for the type
// of each value in types. This is useful for structs you may not have access
// or rights to add tags to, such as types from other packages you're using.
//
// The registry map is created on first use.
//
// NOTES:
//   - this method is not thread-safe; register everything before any
//     transformation runs.
func (t *Transformer) RegisterStructLevel(fn StructLevelFunc, types ...interface{}) {
	registry := t.structLevelFuncs
	if registry == nil {
		registry = make(map[reflect.Type]StructLevelFunc)
		t.structLevelFuncs = registry
	}
	for i := range types {
		registry[reflect.TypeOf(types[i])] = fn
	}
}
// RegisterInterceptor registers fn as the interceptor for the type of each
// value in types. An InterceptorFunc lets an incoming value redirect the
// modifications to an inner type/value.
//
// e.g. sql.NullString
func (t *Transformer) RegisterInterceptor(fn InterceptorFunc, types ...interface{}) {
	for i := range types {
		key := reflect.TypeOf(types[i])
		t.interceptors[key] = fn
	}
}
// Struct applies the registered transformations to the struct that v points to.
//
// v must be a non-nil pointer; otherwise an *ErrInvalidTransformValue is
// returned. The pointed-to value must be a struct other than time.Time;
// otherwise an *ErrInvalidTransformation is returned.
func (t *Transformer) Struct(ctx context.Context, v interface{}) error {
	ptr := reflect.ValueOf(v)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
		return &ErrInvalidTransformValue{typ: reflect.TypeOf(v), fn: "Struct"}
	}

	elem := ptr.Elem()
	elemType := elem.Type()
	if elem.Kind() == reflect.Struct && elemType != timeType {
		return t.setByStruct(ctx, ptr, elem, elemType)
	}
	return &ErrInvalidTransformation{typ: reflect.TypeOf(v)}
}
// setByStruct first runs the struct-level function registered for typ, if
// there is one. It then applies each cached field's tag chain to the matching
// field of current. The struct layout is taken from the struct cache, or
// extracted when it is not cached yet. parent is forwarded to the struct-level
// function. Processing stops at, and returns, the first error.
func (t *Transformer) setByStruct(ctx context.Context, parent, current reflect.Value, typ reflect.Type) (err error) {
	cs, cached := t.cCache.Get(typ)
	if !cached {
		cs, err = t.extractStructCache(current)
		if err != nil {
			return
		}
	}

	// When the struct has a struct-level transformation, it runs before the
	// per-field tags.
	if cs.fn != nil {
		sl := structLevel{
			transformer: t,
			parent:      parent,
			current:     current,
		}
		if err = cs.fn(ctx, sl); err != nil {
			return
		}
	}

	for _, field := range cs.fields {
		if err = t.setByField(ctx, current.Field(field.idx), field.cTags); err != nil {
			return
		}
	}
	return nil
}
// Field applies the transformations listed in tags to the value that v
// points to.
//
// An empty tags string, or the ignore tag by itself, is a no-op. v must be a
// non-nil pointer so that the result can be written back; otherwise an
// *ErrInvalidTransformValue is returned. Successfully parsed tag chains are
// stored in t.tCache keyed by tags, so later calls reuse them.
func (t *Transformer) Field(ctx context.Context, v interface{}, tags string) (err error) {
	if len(tags) == 0 || tags == ignoreTag {
		return nil
	}

	val := reflect.ValueOf(v)
	if val.Kind() != reflect.Ptr || val.IsNil() {
		return &ErrInvalidTransformValue{typ: reflect.TypeOf(v), fn: "Field"}
	}
	val = val.Elem()

	// find cached tag
	ctag, ok := t.tCache.Get(tags)
	if !ok {
		ctag, err = func() (*cTag, error) {
			t.tCache.lock.Lock()
			// defer guarantees the lock is released even if parsing panics;
			// a leaked lock would block every later call that misses the cache.
			defer t.tCache.lock.Unlock()

			// Several goroutines may have missed the cache at once. Re-check
			// under the lock so the first one to finish is the only one that
			// parses this tag.
			if cached, found := t.tCache.Get(tags); found {
				return cached, nil
			}
			parsed, _, perr := t.parseFieldTagsRecursive(tags, "", "", false)
			if perr != nil {
				return nil, perr
			}
			t.tCache.Set(tags, parsed)
			return parsed, nil
		}()
		if err != nil {
			return err
		}
	}

	return t.setByField(ctx, val, ctag)
}
// setByField applies the parsed tag chain ct to orig. It then descends into
// the resulting value if that value is a struct.
//
// While ct has tags, the chain is walked via ct.next:
//   - typeEndKeys: stop and return.
//   - typeDive: apply the rest of the chain to every element of a slice/array
//     (setByIterable) or every entry of a map (setByMap), then return. A
//     pointer whose element is a slice or map is a no-op (the inline comment
//     below notes it is a nil pointer). Any other kind yields ErrInvalidDive.
//   - any other tag: call ct.fn with the current value and ct.param.
//
// After the tags, or when there are none, a struct value other than
// time.Time is handed to setByStruct. The first error encountered is returned.
func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cTag) (err error) {
	current, kind := t.extractType(orig)
	if ct != nil && ct.hasTag {
		for ct != nil {
			switch ct.typeof {
			case typeEndKeys:
				return
			case typeDive:
				// the tags after dive apply to the elements, not the container
				ct = ct.next
				switch kind {
				case reflect.Slice, reflect.Array:
					err = t.setByIterable(ctx, current, ct)
				case reflect.Map:
					err = t.setByMap(ctx, current, ct)
				case reflect.Ptr:
					innerKind := current.Type().Elem().Kind()
					if innerKind == reflect.Slice || innerKind == reflect.Map {
						// is a nil pointer to a slice or map, nothing to do.
						return nil
					}
					// not a valid use of the dive tag
					fallthrough
				default:
					err = ErrInvalidDive
				}
				return
			default:
				if !current.CanAddr() {
					// current is not addressable and cannot be set in place.
					// Run the transform on a settable copy, then store the
					// result back into orig.
					newVal := reflect.New(current.Type()).Elem()
					newVal.Set(current)
					if err = ct.fn(ctx, fieldLevel{
						transformer: t,
						parent:      orig,
						current:     newVal,
						param:       ct.param,
					}); err != nil {
						return
					}
					orig.Set(reflect.Indirect(newVal))
					// re-derive current/kind from the updated orig
					current, kind = t.extractType(orig)
				} else {
					if err = ct.fn(ctx, fieldLevel{
						transformer: t,
						parent:      orig,
						current:     current,
						param:       ct.param,
					}); err != nil {
						return
					}
					// value could have been changed or reassigned
					current, kind = t.extractType(current)
				}
				ct = ct.next
			}
		}
	}
	// need to do this again because one of the previous
	// sets could have set a struct value, where it was a
	// nil pointer before
	orig2 := current
	current, kind = t.extractType(current)
	if kind == reflect.Struct {
		typ := current.Type()
		if typ == timeType {
			// time.Time is not walked as a struct
			return
		}
		if !current.CanAddr() {
			// non-addressable struct: transform a settable copy, then store it back into orig
			newVal := reflect.New(typ).Elem()
			newVal.Set(current)
			if err = t.setByStruct(ctx, orig, newVal, typ); err != nil {
				return
			}
			orig.Set(reflect.Indirect(newVal))
			return
		}
		err = t.setByStruct(ctx, orig2, current, typ)
	}
	return
}
// setByIterable applies the tag chain ct to each element of the slice or
// array current, in index order. It stops at, and returns, the first error.
func (t *Transformer) setByIterable(ctx context.Context, current reflect.Value, ct *cTag) (err error) {
	for i := 0; i < current.Len(); i++ {
		elem := current.Index(i)
		if err = t.setByField(ctx, elem, ct); err != nil {
			return err
		}
	}
	return nil
}
// setByMap applies the tag chain ct to every entry of the map current.
//
// When ct is a keys chain (ct.typeof == typeKeys with non-nil ct.keys),
// ct.keys is applied to each key and ct.next, if present, to each value. The
// entry is then stored under the possibly transformed key. Otherwise ct is
// applied to each value only.
//
// Map keys and values are not addressable, so each one is copied into a fresh
// settable reflect.Value, transformed, and written back with SetMapIndex.
// Processing stops at, and returns, the first error. The entry being processed
// when that error occurs is left unchanged in the map.
func (t *Transformer) setByMap(ctx context.Context, current reflect.Value, ct *cTag) error {
	mapType := current.Type()
	keyType, elemType := mapType.Key(), mapType.Elem()
	transformKeys := ct != nil && ct.typeof == typeKeys && ct.keys != nil

	for _, key := range current.MapKeys() {
		newVal := reflect.New(elemType).Elem()
		newVal.Set(current.MapIndex(key))

		if !transformKeys {
			if err := t.setByField(ctx, newVal, ct); err != nil {
				return err
			}
			current.SetMapIndex(key, newVal)
			continue
		}

		// handle map key on a settable copy
		newKey := reflect.New(keyType).Elem()
		newKey.Set(key)
		if err := t.setByField(ctx, newKey, ct.keys); err != nil {
			return err
		}
		// can be nil when just keys are being transformed
		if ct.next != nil {
			if err := t.setByField(ctx, newVal, ct.next); err != nil {
				return err
			}
		}

		// The key may have changed, so drop the old entry and re-add it. Do
		// this only after both transforms succeeded, so an error cannot
		// silently delete the entry from the map.
		current.SetMapIndex(key, reflect.Value{})
		current.SetMapIndex(newKey, newVal)
	}
	return nil
}