Skip to content

Commit

Permalink
Merge branch 'release-5.2' of github.com:pingcap/tidb into release-5.…
Browse files Browse the repository at this point in the history
…2-8cf847a57514
  • Loading branch information
XuHuaiyu committed Sep 20, 2022
2 parents af86a17 + 071ffd9 commit 6ebc88d
Show file tree
Hide file tree
Showing 21 changed files with 244 additions and 70 deletions.
5 changes: 2 additions & 3 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ type baseAggFunc struct {
// used to append the final result of this function.
ordinal int

// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
// retTp means the target type of the final agg should return.
retTp *types.FieldType
}

func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) {
Expand Down
32 changes: 5 additions & 27 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"fmt"
"strconv"

"github.com/cznic/mathutil"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -194,6 +193,7 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}

// If HasDistinct and mode is CompleteMode or Partial1Mode, we should
Expand Down Expand Up @@ -253,13 +253,9 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -287,16 +283,8 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
Expand Down Expand Up @@ -340,13 +328,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down Expand Up @@ -392,16 +375,11 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
isMax: isMax,
collator: collate.GetCollator(aggFuncDesc.RetTp.Collate),
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down
20 changes: 18 additions & 2 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -71,7 +73,14 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down Expand Up @@ -259,7 +268,14 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -475,7 +477,14 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of first_row should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
10 changes: 9 additions & 1 deletion executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -813,7 +814,14 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of max or min should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -168,7 +170,14 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of sum should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
38 changes: 38 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9247,6 +9247,44 @@ func (s *testSerialSuite) TestUnreasonablyClose(c *C) {
c.Assert(opsAlreadyCoveredMask, Equals, opsNeedsCoveredMask, Commentf("these operators are not covered %s", commentBuf.String()))
}

func (s *testSerialSuite) TestIssue29498(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("DROP TABLE IF EXISTS t1;")
tk.MustExec("CREATE TABLE t1 (t3 TIME(3), d DATE, t TIME);")
tk.MustExec("INSERT INTO t1 VALUES ('00:00:00.567', '2002-01-01', '00:00:02');")

res := tk.MustQuery("SELECT CONCAT(IFNULL(t3, d)) AS col1 FROM t1;")
row := res.Rows()[0][0].(string)
c.Assert(len(row), Equals, mysql.MaxDatetimeWidthNoFsp+3+1)
c.Assert(row[len(row)-12:], Equals, "00:00:00.567")

res = tk.MustQuery("SELECT IFNULL(t3, d) AS col1 FROM t1;")
row = res.Rows()[0][0].(string)
c.Assert(len(row), Equals, mysql.MaxDatetimeWidthNoFsp+3+1)
c.Assert(row[len(row)-12:], Equals, "00:00:00.567")

res = tk.MustQuery("SELECT CONCAT(IFNULL(t, d)) AS col1 FROM t1;")
row = res.Rows()[0][0].(string)
c.Assert(len(row), Equals, mysql.MaxDatetimeWidthNoFsp)
c.Assert(row[len(row)-8:], Equals, "00:00:02")

res = tk.MustQuery("SELECT IFNULL(t, d) AS col1 FROM t1;")
row = res.Rows()[0][0].(string)
c.Assert(len(row), Equals, mysql.MaxDatetimeWidthNoFsp)
c.Assert(row[len(row)-8:], Equals, "00:00:02")

res = tk.MustQuery("SELECT CONCAT(xx) FROM (SELECT t3 AS xx FROM t1 UNION SELECT d FROM t1) x ORDER BY -xx LIMIT 1;")
row = res.Rows()[0][0].(string)
c.Assert(len(row), Equals, mysql.MaxDatetimeWidthNoFsp+3+1)
c.Assert(row[len(row)-12:], Equals, "00:00:00.567")

res = tk.MustQuery("SELECT CONCAT(CASE WHEN d IS NOT NULL THEN t3 ELSE d END) AS col1 FROM t1;")
row = res.Rows()[0][0].(string)
c.Assert(len(row), Equals, mysql.MaxDatetimeWidthNoFsp+3+1)
c.Assert(row[len(row)-12:], Equals, "00:00:00.567")
}

func (s *testSuite) TestDeleteWithMulTbl(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
21 changes: 19 additions & 2 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,11 +740,28 @@ func (s *tiflashTestSuite) TestUnionWithEmptyDualTable(c *C) {
func (s *tiflashTestSuite) TestAvgOverflow(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
// avg int
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a decimal(1,0))")
tk.MustExec("alter table t set tiflash replica 1")
tb := testGetTableByName(c, tk.Se, "test", "t")
err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into t values(9)")
for i := 0; i < 16; i++ {
tk.MustExec("insert into t select * from t")
}
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("set @@session.tidb_enforce_mpp=ON")
tk.MustQuery("select avg(a) from t group by a").Check(testkit.Rows("9.0000"))
tk.MustExec("drop table if exists t")

// avg decimal
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("alter table td set tiflash replica 1")
tb := testGetTableByName(c, tk.Se, "test", "td")
err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
tb = testGetTableByName(c, tk.Se, "test", "td")
err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
Expand Down
32 changes: 17 additions & 15 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type baseFuncDesc struct {

func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
err := b.typeInfer(ctx)
err := b.TypeInfer(ctx)
return b, err
}

Expand Down Expand Up @@ -83,8 +83,8 @@ func (a *baseFuncDesc) String() string {
return buffer.String()
}

// typeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
// TypeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
Expand Down Expand Up @@ -206,6 +206,14 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
types.SetBinChsClnFlag(a.RetTp)
}

// TypeInfer4AvgSum infers the type of sum from avg, which should extend the precision of decimal
// compatible with mysql.
func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) {
if avgRetType.Tp == mysql.TypeNewDecimal {
a.RetTp.Flen = mathutil.Min(mysql.MaxDecimalWidth, a.RetTp.Flen+22)
}
}

// typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
Expand Down Expand Up @@ -245,6 +253,12 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {

a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0
// TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i])
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}

}

func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
Expand Down Expand Up @@ -368,18 +382,6 @@ var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncJsonObjectAgg: {},
}

// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions
// with a cast as decimal function. See issue #19426
func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) {
if a.Name == ast.AggFuncGroupConcat {
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}
}
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {
Expand Down
3 changes: 3 additions & 0 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []E
argTp := args[0].GetType().EvalType()
switch argTp {
case types.ETInt:
if bf.tp.Flen == types.UnspecifiedLength {
bf.tp.Flen = args[0].GetType().Flen
}
sig = &builtinCastIntAsStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsString)
case types.ETReal:
Expand Down
Loading

0 comments on commit 6ebc88d

Please sign in to comment.