@@ -48,58 +48,42 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
4848 }
4949 }
5050
51- result , err := exp .TransformExpressions (func (e sql.Expression ) (sql.Expression , error ) {
52- var result sql.Expression
53-
54- // No need to wrap expressions that already validate times, such as
55- // convert, date_add, etc and those expressions whose Type method
56- // cannot be called because they are placeholders.
57- switch e .(type ) {
58- case * expression.Convert ,
59- * expression.Arithmetic ,
60- * function.DateAdd ,
61- * function.DateSub ,
62- * expression.Star ,
63- * expression.DefaultColumn ,
64- * expression.Alias :
65- return e , nil
66- default :
67- // If it's a replacement, just replace it with the correct GetField
68- // because we know that it's already converted to a correct date
69- // and there is no point to do so again.
70- if gf , ok := e .(* expression.GetField ); ok {
71- if name , ok := replacements [tableCol {gf .Table (), gf .Name ()}]; ok {
72- return expression .NewGetField (gf .Index (), gf .Type (), name , gf .IsNullable ()), nil
73- }
74- }
75-
76- switch e .Type () {
77- case sql .Date :
78- result = expression .NewConvert (e , expression .ConvertToDate )
79- case sql .Timestamp :
80- result = expression .NewConvert (e , expression .ConvertToDatetime )
81- default :
82- result = e
51+ var result sql.Node
52+ var err error
53+ switch exp := exp .(type ) {
54+ case * plan.GroupBy :
55+ var aggregate = make ([]sql.Expression , len (exp .Aggregate ))
56+ for i , a := range exp .Aggregate {
57+ agg , err := a .TransformUp (func (e sql.Expression ) (sql.Expression , error ) {
58+ return addDateConvert (e , exp , replacements , nodeReplacements , expressions , true )
59+ })
60+ if err != nil {
61+ return nil , err
8362 }
63+ aggregate [i ] = agg
8464 }
8565
86- // Only do this if it's a root expression in a project or group by.
87- switch exp .(type ) {
88- case * plan.Project , * plan.GroupBy :
89- // If it was originally a GetField, and it's not anymore it's
90- // because we wrapped it in a convert. We need to make it an alias
91- // and propagate the changes up the chain.
92- if gf , ok := e .(* expression.GetField ); ok && expressions [e .String ()] {
93- if _ , ok := result .(* expression.GetField ); ! ok {
94- name := fmt .Sprintf ("%s__%s" , gf .Table (), gf .Name ())
95- result = expression .NewAlias (result , name )
96- nodeReplacements [tableCol {gf .Table (), gf .Name ()}] = name
97- }
66+ var grouping = make ([]sql.Expression , len (exp .Grouping ))
67+ for i , g := range exp .Grouping {
68+ gr , err := g .TransformUp (func (e sql.Expression ) (sql.Expression , error ) {
69+ return addDateConvert (e , exp , replacements , nodeReplacements , expressions , false )
70+ })
71+ if err != nil {
72+ return nil , err
9873 }
74+ grouping [i ] = gr
9975 }
10076
101- return result , nil
102- })
77+ result = plan .NewGroupBy (aggregate , grouping , exp .Child )
78+ default :
79+ result , err = exp .TransformExpressions (func (e sql.Expression ) (sql.Expression , error ) {
80+ return addDateConvert (e , n , replacements , nodeReplacements , expressions , true )
81+ })
82+ }
83+
84+ if err != nil {
85+ return nil , err
86+ }
10387
10488 // We're done with this node, so copy all the replacements found in
10589 // this node to the global replacements in order to make the necesssary
@@ -111,3 +95,62 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
11195 return result , err
11296 })
11397}
98+
99+ func addDateConvert (
100+ e sql.Expression ,
101+ node sql.Node ,
102+ replacements , nodeReplacements map [tableCol ]string ,
103+ expressions map [string ]bool ,
104+ aliasRootProjections bool ,
105+ ) (sql.Expression , error ) {
106+ var result sql.Expression
107+
108+ // No need to wrap expressions that already validate times, such as
109+ // convert, date_add, etc and those expressions whose Type method
110+ // cannot be called because they are placeholders.
111+ switch e .(type ) {
112+ case * expression.Convert ,
113+ * expression.Arithmetic ,
114+ * function.DateAdd ,
115+ * function.DateSub ,
116+ * expression.Star ,
117+ * expression.DefaultColumn ,
118+ * expression.Alias :
119+ return e , nil
120+ default :
121+ // If it's a replacement, just replace it with the correct GetField
122+ // because we know that it's already converted to a correct date
123+ // and there is no point to do so again.
124+ if gf , ok := e .(* expression.GetField ); ok {
125+ if name , ok := replacements [tableCol {gf .Table (), gf .Name ()}]; ok {
126+ return expression .NewGetField (gf .Index (), gf .Type (), name , gf .IsNullable ()), nil
127+ }
128+ }
129+
130+ switch e .Type () {
131+ case sql .Date :
132+ result = expression .NewConvert (e , expression .ConvertToDate )
133+ case sql .Timestamp :
134+ result = expression .NewConvert (e , expression .ConvertToDatetime )
135+ default :
136+ result = e
137+ }
138+ }
139+
140+ // Only do this if it's a root expression in a project or group by.
141+ switch node .(type ) {
142+ case * plan.Project , * plan.GroupBy :
143+ // If it was originally a GetField, and it's not anymore it's
144+ // because we wrapped it in a convert. We need to make it an alias
145+ // and propagate the changes up the chain.
146+ if gf , ok := e .(* expression.GetField ); ok && expressions [e .String ()] && aliasRootProjections {
147+ if _ , ok := result .(* expression.GetField ); ! ok {
148+ name := fmt .Sprintf ("%s__%s" , gf .Table (), gf .Name ())
149+ result = expression .NewAlias (result , name )
150+ nodeReplacements [tableCol {gf .Table (), gf .Name ()}] = name
151+ }
152+ }
153+ }
154+
155+ return result , nil
156+ }
0 commit comments