-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathreplace.go
107 lines (97 loc) · 2.58 KB
/
replace.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package querydigest
import (
"bytes"
"fmt"
"io"
"log"
"sync"
"time"
"github.com/akito0107/xsqlparser"
"github.com/akito0107/xsqlparser/dialect"
"github.com/akito0107/xsqlparser/sqlast"
"github.com/akito0107/xsqlparser/sqlastutil"
"github.com/akito0107/xsqlparser/sqltoken"
)
// tokensPool recycles token slices between tokenization runs. A freshly
// created slice is empty with room for 2048 token pointers; callers reslice
// to [:0] after Get and may reuse *Token values left in the backing array.
var tokensPool = sync.Pool{New: func() interface{} { return make([]*sqltoken.Token, 0, 2048) }}
// tokenizerPool recycles *sqltoken.Tokenizer values configured with the MySQL
// dialect and the DisableParseComment option. New tokenizers are created with
// a nil reader, so whoever takes one from the pool must point its Scanner at
// the input (and reset its position counters) before scanning.
var tokenizerPool = sync.Pool{
	New: func() interface{} {
		mysql := sqltoken.Dialect(&dialect.MySQLDialect{})
		noComments := sqltoken.DisableParseComment()
		return sqltoken.NewTokenizerWithOptions(nil, mysql, noComments)
	},
}
// ReplaceWithZeroValue parses a single SQL statement from src (tokenized with
// the pooled MySQL-dialect tokenizer) and re-renders it with every literal
// replaced by a fixed placeholder, so statements that differ only in their
// literal arguments produce the same SQL text.
//
// Replacements applied to the parsed AST:
//   - integer literals        -> 0
//   - floating-point literals -> 0
//   - boolean literals        -> true
//   - single-quoted strings   -> ''
//   - timestamp/time/datetime -> 1970-01-01 00:00:00 UTC
//   - IN lists                -> rebuilt keeping only Expr, Negated and RParen;
//     the other fields (presumably including the element list) are left at
//     their zero values, so IN lists of different lengths collapse together.
//
// It returns an error when tokenizing or parsing fails, or when the parser
// panics on the input (the panic is recovered and reported as an error).
func ReplaceWithZeroValue(src []byte) (result string, err error) {
	// FIXME evil work around: the parser can panic on some inputs. Recover so a
	// single bad query does not crash the caller, but report it as an error
	// rather than an indistinguishable ("", nil).
	defer func() {
		if r := recover(); r != nil {
			log.Printf("fatal err")
			result = ""
			err = fmt.Errorf("replace with zero value: recovered panic: %v", r)
		}
	}()

	// Tokenizers are pooled: reset the position counters and point the scanner
	// at this input before use.
	tokenizer := tokenizerPool.Get().(*sqltoken.Tokenizer)
	tokenizer.Line = 1
	tokenizer.Col = 1
	tokenizer.Scanner.Init(bytes.NewReader(src))
	defer tokenizerPool.Put(tokenizer)

	tokset := tokensPool.Get().([]*sqltoken.Token)
	tokset = tokset[:0]
	// Use a closure so the final (possibly re-grown) slice goes back to the pool.
	defer func() {
		tokensPool.Put(tokset)
	}()

	for {
		// Reuse a *Token left in the backing array by an earlier call, if any,
		// instead of allocating one per token.
		var tok *sqltoken.Token
		if len(tokset) < cap(tokset) {
			tok = tokset[:len(tokset)+1][len(tokset)]
		}
		if tok == nil {
			tok = &sqltoken.Token{}
		}
		t, scanErr := tokenizer.Scan(tok)
		if scanErr == io.EOF {
			break
		}
		if scanErr != nil {
			return "", fmt.Errorf("tokenize failed src: %s : %w", string(src), scanErr)
		}
		if t == nil {
			continue
		}
		tokset = append(tokset, tok)
	}

	parser := xsqlparser.NewParserWithOptions()
	parser.SetTokens(tokset)
	stmt, err := parser.ParseStatement()
	if err != nil {
		// Log at most the first 50 bytes; an unconditional src[:50] panicked on
		// shorter inputs and masked the parse error.
		snippet := src
		if len(snippet) > 50 {
			snippet = snippet[:50]
		}
		log.Printf("Parse failed: invalied sql: %s \n", snippet)
		return "", err
	}

	// time.Date panics on a nil *Location, so the placeholder must carry an
	// explicit zone.
	epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)

	res := sqlastutil.Apply(stmt, func(cursor *sqlastutil.Cursor) bool {
		switch node := cursor.Node().(type) {
		case *sqlast.LongValue:
			cursor.Replace(sqlast.NewLongValue(0))
		case *sqlast.DoubleValue:
			cursor.Replace(sqlast.NewDoubleValue(0))
		case *sqlast.BooleanValue:
			cursor.Replace(sqlast.NewBooleanValue(true))
		case *sqlast.SingleQuotedString:
			cursor.Replace(sqlast.NewSingleQuotedString(""))
		case *sqlast.TimestampValue:
			cursor.Replace(sqlast.NewTimestampValue(epoch))
		case *sqlast.TimeValue:
			cursor.Replace(sqlast.NewTimeValue(epoch))
		case *sqlast.DateTimeValue:
			cursor.Replace(sqlast.NewDateTimeValue(epoch))
		case *sqlast.InList:
			// Keep only the tested expression, negation and closing paren so
			// lists of any length normalize to the same shape.
			cursor.Replace(&sqlast.InList{
				Expr:    node.Expr,
				Negated: node.Negated,
				RParen:  node.RParen,
			})
		}
		return true
	}, nil)
	return res.ToSQLString(), nil
}