-
Notifications
You must be signed in to change notification settings - Fork 707
/
test_transforms.py
246 lines (236 loc) · 12.8 KB
/
test_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import unittest
from sqlglot import parse_one
from sqlglot.transforms import (
eliminate_distinct_on,
eliminate_join_marks,
eliminate_qualify,
remove_precision_parameterized_types,
unalias_group,
)
class TestTransforms(unittest.TestCase):
    """Round-trip checks for the AST rewrites in ``sqlglot.transforms``.

    Every case parses an input query, applies one transform through
    ``Expression.transform`` and compares the regenerated SQL text with an
    expected string.
    """

    maxDiff = None

    def validate(self, transform, sql, target, dialect=None):
        """Assert that *transform* applied to *sql* renders exactly *target*.

        ``dialect`` is used for both parsing and SQL generation, and it is
        part of the sub-test label so a failure names the offending case.
        """
        label = f"{dialect} - {sql}"
        with self.subTest(label):
            rendered = parse_one(sql, dialect=dialect).transform(transform).sql(dialect=dialect)
            self.assertEqual(rendered, target)

    def _validate_cases(self, transform, cases, dialect=None):
        """Run :meth:`validate` on each ``(sql, target)`` pair of *cases*, in order."""
        for source_sql, expected_sql in cases:
            self.validate(transform, source_sql, expected_sql, dialect)

    def test_unalias_group(self):
        """GROUP BY references to projection aliases become 1-based ordinals.

        Grouping expressions that merely contain an alias name (e.g. inside a
        function call) are expected to be left as written.
        """
        cases = [
            (
                "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4",
                "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, 2, x.c, 4",
            ),
            (
                "SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), aliased_last",
                "SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), 3",
            ),
            (
                "SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))",
                "SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))",
            ),
            (
                "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY new_date",
                "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY 1",
            ),
            (
                "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
                "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY 1",
            ),
            (
                "SELECT a AS a FROM x GROUP BY DATE(a)",
                "SELECT a AS a FROM x GROUP BY DATE(a)",
            ),
        ]
        self._validate_cases(unalias_group, cases)

    def test_eliminate_distinct_on(self):
        """DISTINCT ON becomes a ROW_NUMBER() filter over a derived table.

        The expected outputs cover several situations:
        - The window is partitioned by the DISTINCT ON keys.
        - Without an ORDER BY, the keys themselves order the window.
        - A plain DISTINCT is left alone.
        - Clashing generated names get a numeric suffix (``_row_number_2``, ``a_2``).
        """
        cases = [
            (
                "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
                "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
            ),
            # No ORDER BY: the DISTINCT ON key is reused as the window ordering.
            (
                "SELECT DISTINCT ON (a) a, b FROM x",
                "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1",
            ),
            (
                "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
                "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
            ),
            # Plain DISTINCT is not a DISTINCT ON and must pass through untouched.
            (
                "SELECT DISTINCT a, b FROM x ORDER BY c DESC",
                "SELECT DISTINCT a, b FROM x ORDER BY c DESC",
            ),
            (
                "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
                "SELECT _row_number FROM (SELECT _row_number AS _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1",
            ),
            (
                "SELECT DISTINCT ON (x.a, x.b) x.a, x.b FROM x ORDER BY c DESC",
                "SELECT a, b FROM (SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a, x.b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
            ),
            (
                "SELECT DISTINCT ON (a) x.a, y.a FROM x CROSS JOIN y ORDER BY c DESC",
                "SELECT a, a_2 FROM (SELECT x.a AS a, y.a AS a_2, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x CROSS JOIN y) AS _t WHERE _row_number = 1",
            ),
            (
                "SELECT DISTINCT ON (a) a, a + b FROM x ORDER BY c DESC",
                "SELECT a, _col FROM (SELECT a AS a, a + b AS _col, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
            ),
            (
                "SELECT DISTINCT ON (a) * FROM x ORDER BY c DESC",
                "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
            ),
            (
                'SELECT DISTINCT ON (a) a AS "A", b FROM x ORDER BY c DESC',
                'SELECT "A", b FROM (SELECT a AS "A", b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1',
            ),
            (
                'SELECT DISTINCT ON (a) "A", b FROM x ORDER BY c DESC',
                'SELECT "A", b FROM (SELECT "A" AS "A", b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1',
            ),
        ]
        self._validate_cases(eliminate_distinct_on, cases)

    def test_eliminate_qualify(self):
        """QUALIFY is moved into the WHERE clause of a wrapping derived table.

        In the expected outputs:
        - The inner query exposes the window result as ``_w`` and unnamed
          expressions as ``_c``.
        - The inner query also projects the columns the window or condition needs.
        - Projection aliases used inside the window are expanded to their
          underlying expressions.
        """
        cases = [
            (
                "SELECT i, a + 1 FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p) = 1",
                "SELECT i, _c FROM (SELECT i, a + 1 AS _c, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM qt) AS _t WHERE _w = 1",
            ),
            (
                "SELECT i FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1 AND p = 0",
                "SELECT i FROM (SELECT i, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1 AND p = 0",
            ),
            (
                "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
                "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1",
            ),
            # The condition refers to an existing window alias, so no _w is added.
            (
                "SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt QUALIFY row_num = 1",
                "SELECT i, p, o, row_num FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt) AS _t WHERE row_num = 1",
            ),
            (
                "SELECT * FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
                "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1",
            ),
            (
                "SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3 QUALIFY r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)",
                "SELECT c2, r FROM (SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r, c1 FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3) AS _t WHERE r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)",
            ),
            (
                "SELECT x FROM y QUALIFY ROW_NUMBER() OVER (PARTITION BY p)",
                "SELECT x FROM (SELECT x, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM y) AS _t WHERE _w",
            ),
            (
                "SELECT x AS z FROM y QUALIFY ROW_NUMBER() OVER (PARTITION BY z)",
                "SELECT z FROM (SELECT x AS z, ROW_NUMBER() OVER (PARTITION BY x) AS _w, x FROM y) AS _t WHERE _w",
            ),
            (
                "SELECT SOME_UDF(x) AS z FROM y QUALIFY ROW_NUMBER() OVER (PARTITION BY x ORDER BY z)",
                "SELECT z FROM (SELECT SOME_UDF(x) AS z, ROW_NUMBER() OVER (PARTITION BY x ORDER BY SOME_UDF(x)) AS _w, x FROM y) AS _t WHERE _w",
            ),
            (
                "SELECT x, t, x || t AS z FROM y QUALIFY ROW_NUMBER() OVER (PARTITION BY x ORDER BY z DESC)",
                "SELECT x, t, z FROM (SELECT x, t, x || t AS z, ROW_NUMBER() OVER (PARTITION BY x ORDER BY x || t DESC) AS _w FROM y) AS _t WHERE _w",
            ),
            (
                "SELECT y.x AS x, y.t AS z FROM y QUALIFY ROW_NUMBER() OVER (PARTITION BY x ORDER BY x DESC, z)",
                "SELECT x, z FROM (SELECT y.x AS x, y.t AS z, ROW_NUMBER() OVER (PARTITION BY y.x ORDER BY y.x DESC, y.t) AS _w, y.t FROM y) AS _t WHERE _w",
            ),
        ]
        self._validate_cases(eliminate_qualify, cases)

    def test_remove_precision_parameterized_types(self):
        """Type parameters such as precision and length are dropped from CASTs."""
        cases = [
            (
                "SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))",
                "SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)",
            ),
        ]
        self._validate_cases(remove_precision_parameterized_types, cases)

    def test_eliminate_join_marks(self):
        """Legacy ``(+)`` outer-join markers become explicit LEFT JOINs.

        Every case is checked for both the Oracle and Redshift dialects.
        """
        for dialect in ("oracle", "redshift"):
            # The expected SQL writes the derived-table alias without AS for
            # Oracle and with AS for the other dialect.
            alias = "" if dialect == "oracle" else "AS "
            cases = [
                (
                    "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
                    "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
                ),
                (
                    "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x (+) = T2.x and T2.y > 5",
                    "SELECT T1.d, T2.c FROM T2 LEFT JOIN T1 ON T1.x = T2.x WHERE T2.y > 5",
                ),
                (
                    "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
                    "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
                ),
                (
                    "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
                    "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
                ),
                (
                    "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
                    "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
                ),
                (
                    "SELECT * FROM table1, table2 WHERE table1.col = table2.col(+)",
                    "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col",
                ),
                (
                    "SELECT * FROM table1, table2, table3, table4 WHERE table1.col = table2.col(+) and table2.col >= table3.col(+) and table1.col = table4.col(+)",
                    "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col LEFT JOIN table3 ON table2.col >= table3.col LEFT JOIN table4 ON table1.col = table4.col",
                ),
                (
                    "SELECT * FROM table1, table2, table3 WHERE table1.col = table2.col(+) and table2.col >= table3.col(+)",
                    "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col LEFT JOIN table3 ON table2.col >= table3.col",
                ),
                # Two join marks on the same side of the predicate.
                (
                    "SELECT * FROM table1, table2 WHERE table1.col = table2.col1(+) + table2.col2(+)",
                    "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col1 + table2.col2",
                ),
                # A join mark inside a larger expression.
                (
                    "SELECT * FROM table1, table2 WHERE table1.col = table2.col1(+) + 25",
                    "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col1 + 25",
                ),
                # Join marks both in a subquery and in the outer query.
                (
                    "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
                    f"SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) {alias}table3 ON table1.id = table3.id",
                ),
            ]
            self._validate_cases(eliminate_join_marks, cases, dialect)