diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index a325842dd4993..5d44c9b3d05a6 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -360,6 +360,9 @@ func (sf *ScalarFunction) Equal(ctx EvalContext, e Expression) bool { if sf.FuncName.L != fun.FuncName.L { return false } + if !sf.RetType.Equal(fun.RetType) { + return false + } return sf.Function.equal(ctx, fun.Function) } diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index cb922c0948cd1..924cf8905c4ee 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -259,7 +259,7 @@ func TestSubstituteCorCol2Constant(t *testing.T) { ret, err = SubstituteCorCol2Constant(ctx, plus3) require.NoError(t, err) ans3 := newFunctionWithMockCtx(ast.Plus, ans1, col1) - require.True(t, ret.Equal(ctx, ans3)) + require.False(t, ret.Equal(ctx, ans3)) } func TestPushDownNot(t *testing.T) { diff --git a/pkg/parser/types/field_type.go b/pkg/parser/types/field_type.go index 2e80bbf3c7d2b..2befced6393bb 100644 --- a/pkg/parser/types/field_type.go +++ b/pkg/parser/types/field_type.go @@ -289,7 +289,7 @@ func (ft *FieldType) Equal(other *FieldType) bool { // because flen for them is useless. // The decimal field can be ignored if the type is int or string. tpEqual := (ft.GetType() == other.GetType()) || (ft.GetType() == mysql.TypeVarchar && other.GetType() == mysql.TypeVarString) || (ft.GetType() == mysql.TypeVarString && other.GetType() == mysql.TypeVarchar) - flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) + flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) || ft.EvalType() == ETJson ignoreDecimal := ft.EvalType() == ETInt || ft.EvalType() == ETString partialEqual := tpEqual && (ignoreDecimal || ft.decimal == other.decimal) && diff --git a/pkg/planner/core/issuetest/planner_issue_test.go b/pkg/planner/core/issuetest/planner_issue_test.go index be9e5ef29d971..772a66b644388 100644 --- a/pkg/planner/core/issuetest/planner_issue_test.go +++ b/pkg/planner/core/issuetest/planner_issue_test.go @@ -57,3 +57,31 @@ func TestIssue43461(t *testing.T) { require.NotEqual(t, is.Columns, ts.Columns) } + +func Test53726(t *testing.T) { + // test for RemoveUnnecessaryFirstRow + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t7(c int); ") + tk.MustExec("insert into t7 values (575932053), (-258025139);") + tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7"). + Sort().Check(testkit.Rows("-258025139 -258025139", "575932053 575932053")) + tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7"). + Check(testkit.Rows( + "HashAgg_8 8000.00 root group by:Column#7, Column#8, funcs:firstrow(Column#7)->Column#3, funcs:firstrow(Column#8)->Column#4", + "└─TableReader_9 8000.00 root data:HashAgg_4", + " └─HashAgg_4 8000.00 cop[tikv] group by:cast(test.t7.c, bigint(22) BINARY), cast(test.t7.c, decimal(10,0) BINARY), ", + " └─TableFullScan_7 10000.00 cop[tikv] table:t7 keep order:false, stats:pseudo")) + + tk.MustExec("analyze table t7") + tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7"). + Sort(). + Check(testkit.Rows("-258025139 -258025139", "575932053 575932053")) + tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7"). + Check(testkit.Rows( + "HashAgg_6 2.00 root group by:Column#13, Column#14, funcs:firstrow(Column#11)->Column#3, funcs:firstrow(Column#12)->Column#4", + "└─Projection_12 2.00 root cast(test.t7.c, decimal(10,0) BINARY)->Column#11, cast(test.t7.c, bigint(22) BINARY)->Column#12, cast(test.t7.c, decimal(10,0) BINARY)->Column#13, cast(test.t7.c, bigint(22) BINARY)->Column#14", + " └─TableReader_11 2.00 root data:TableFullScan_10", + " └─TableFullScan_10 2.00 cop[tikv] table:t7 keep order:false")) +}