Skip to content

Commit 1a24c03

Browse files
authored
expression: correct the erroneous scalar function equivalence check (#54067)
close #53726
1 parent d1f2671 commit 1a24c03

File tree

4 files changed

+33
-2
lines changed

4 files changed

+33
-2
lines changed

pkg/expression/scalar_function.go

+3
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,9 @@ func (sf *ScalarFunction) Equal(ctx EvalContext, e Expression) bool {
365365
if sf.FuncName.L != fun.FuncName.L {
366366
return false
367367
}
368+
if !sf.RetType.Equal(fun.RetType) {
369+
return false
370+
}
368371
return sf.Function.equal(ctx, fun.Function)
369372
}
370373

pkg/expression/util_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func TestSubstituteCorCol2Constant(t *testing.T) {
259259
ret, err = SubstituteCorCol2Constant(ctx, plus3)
260260
require.NoError(t, err)
261261
ans3 := newFunctionWithMockCtx(ast.Plus, ans1, col1)
262-
require.True(t, ret.Equal(ctx, ans3))
262+
require.False(t, ret.Equal(ctx, ans3))
263263
}
264264

265265
func TestPushDownNot(t *testing.T) {

pkg/parser/types/field_type.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func (ft *FieldType) Equal(other *FieldType) bool {
289289
// because flen for them is useless.
290290
// The decimal field can be ignored if the type is int or string.
291291
tpEqual := (ft.GetType() == other.GetType()) || (ft.GetType() == mysql.TypeVarchar && other.GetType() == mysql.TypeVarString) || (ft.GetType() == mysql.TypeVarString && other.GetType() == mysql.TypeVarchar)
292-
flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength)
292+
flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) || ft.EvalType() == ETJson
293293
ignoreDecimal := ft.EvalType() == ETInt || ft.EvalType() == ETString
294294
partialEqual := tpEqual &&
295295
(ignoreDecimal || ft.decimal == other.decimal) &&

pkg/planner/core/issuetest/planner_issue_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,31 @@ func TestIssue43461(t *testing.T) {
5858

5959
require.NotEqual(t, is.Columns, ts.Columns)
6060
}
61+
62+
func Test53726(t *testing.T) {
63+
// test for RemoveUnnecessaryFirstRow
64+
store := testkit.CreateMockStore(t)
65+
tk := testkit.NewTestKit(t, store)
66+
tk.MustExec("use test")
67+
tk.MustExec("create table t7(c int); ")
68+
tk.MustExec("insert into t7 values (575932053), (-258025139);")
69+
tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7").
70+
Sort().Check(testkit.Rows("-258025139 -258025139", "575932053 575932053"))
71+
tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7").
72+
Check(testkit.Rows(
73+
"HashAgg_8 8000.00 root group by:Column#7, Column#8, funcs:firstrow(Column#7)->Column#3, funcs:firstrow(Column#8)->Column#4",
74+
"└─TableReader_9 8000.00 root data:HashAgg_4",
75+
" └─HashAgg_4 8000.00 cop[tikv] group by:cast(test.t7.c, bigint(22) BINARY), cast(test.t7.c, decimal(10,0) BINARY), ",
76+
" └─TableFullScan_7 10000.00 cop[tikv] table:t7 keep order:false, stats:pseudo"))
77+
78+
tk.MustExec("analyze table t7")
79+
tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7").
80+
Sort().
81+
Check(testkit.Rows("-258025139 -258025139", "575932053 575932053"))
82+
tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7").
83+
Check(testkit.Rows(
84+
"HashAgg_6 2.00 root group by:Column#13, Column#14, funcs:firstrow(Column#11)->Column#3, funcs:firstrow(Column#12)->Column#4",
85+
"└─Projection_12 2.00 root cast(test.t7.c, decimal(10,0) BINARY)->Column#11, cast(test.t7.c, bigint(22) BINARY)->Column#12, cast(test.t7.c, decimal(10,0) BINARY)->Column#13, cast(test.t7.c, bigint(22) BINARY)->Column#14",
86+
" └─TableReader_11 2.00 root data:TableFullScan_10",
87+
" └─TableFullScan_10 2.00 cop[tikv] table:t7 keep order:false"))
88+
}

0 commit comments

Comments
 (0)