From c953071436f058e53304f9078888d74a4218346e Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 23 Nov 2021 14:23:18 +0800 Subject: [PATCH] done Signed-off-by: wjhuang2016 --- executor/aggfuncs/aggfunc_test.go | 38 ++++++++++++--------- store/mockstore/mockcopr/aggregate.go | 4 ++- store/mockstore/mockcopr/cop_handler_dag.go | 5 +++ 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index 55e6c951d7cdc..d106f6b824e62 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/set" @@ -317,6 +318,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { iter := chunk.NewIterator4Chunk(srcChk) args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} + ctor := collate.GetCollator(p.dataType.Collate) if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)}) } @@ -359,7 +361,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { if p.funcName == ast.AggFuncJsonArrayagg { dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) } - result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) @@ -386,7 +388,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { if p.funcName == ast.AggFuncJsonArrayagg { dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) } - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) _, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr) @@ -409,7 +411,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { if p.funcName == ast.AggFuncJsonArrayagg { dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) } - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[2]) } @@ -447,6 +449,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul {Expr: args[0], Desc: true}, } } + ctor := collate.GetCollator(args[0].GetType().Collate) partialDesc, finalDesc := desc.Split([]int{0, 1}) // build partial func for partial phase. @@ -467,7 +470,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, p.retType) - result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Zero(t, result) @@ -488,7 +491,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, p.retType) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Zero(t, result) _, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr) @@ -499,7 +502,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, p.retType) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor) require.NoError(t, err) require.Zero(t, result) } @@ -570,6 +573,7 @@ func testAggFunc(t *testing.T, p aggTest) { ctx := mock.NewContext() args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} + ctor := collate.GetCollator(p.dataType.Collate) if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)}) } @@ -596,7 +600,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) @@ -606,7 +610,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) @@ -639,7 +643,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) @@ -649,7 +653,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) } @@ -658,6 +662,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) { srcChk := p.genSrcChk() args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} + ctor := collate.GetCollator(p.dataType.Collate) if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)}) } @@ -685,7 +690,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[1]) @@ -695,7 +700,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[0]) } @@ -749,6 +754,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe {Expr: args[0], Desc: true}, } } + ctor := collate.GetCollator(args[0].GetType().Collate) finalFunc := aggfuncs.Build(ctx, desc, 0) finalPr, _ := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) @@ -762,7 +768,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[1]) @@ -772,7 +778,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[0]) @@ -805,7 +811,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[1]) @@ -815,7 +821,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) require.NoError(t, err) require.Zero(t, result) } diff --git a/store/mockstore/mockcopr/aggregate.go b/store/mockstore/mockcopr/aggregate.go index 0a0dd6d6cfb56..1e7bcc1a9207a 100644 --- a/store/mockstore/mockcopr/aggregate.go +++ b/store/mockstore/mockcopr/aggregate.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" ) type aggCtxsMapper map[string][]*aggregation.AggEvaluateContext @@ -208,6 +209,7 @@ type streamAggExec struct { aggExprs []aggregation.Aggregation aggCtxs []*aggregation.AggEvaluateContext groupByExprs []expression.Expression + groupByCollators []collate.Collator relatedColOffsets []int row []types.Datum tmpGroupByRow []types.Datum @@ -288,7 +290,7 @@ func (e *streamAggExec) meetNewGroup(row [][]byte) (bool, error) { return false, errors.Trace(err) } if matched { - c, err := d.CompareDatum(e.evalCtx.sc, &e.nextGroupByRow[i]) + c, err := d.Compare(e.evalCtx.sc, &e.nextGroupByRow[i], e.groupByCollators[i]) if err != nil { return false, errors.Trace(err) } diff --git a/store/mockstore/mockcopr/cop_handler_dag.go b/store/mockstore/mockcopr/cop_handler_dag.go index 4d71a1caa3994..2fd9d8f73d4b3 100644 --- a/store/mockstore/mockcopr/cop_handler_dag.go +++ b/store/mockstore/mockcopr/cop_handler_dag.go @@ -396,12 +396,17 @@ func (h coprHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (* for _, agg := range aggs { aggCtxs = append(aggCtxs, agg.CreateContext(ctx.evalCtx.sc)) } + groupByCollators := make([]collate.Collator, 0, len(groupBys)) + for _, expr := range groupBys { + groupByCollators = append(groupByCollators, collate.GetCollator(expr.GetType().Collate)) + } return &streamAggExec{ evalCtx: ctx.evalCtx, aggExprs: aggs, aggCtxs: aggCtxs, groupByExprs: groupBys, + groupByCollators: groupByCollators, currGroupByValues: make([][]byte, 0), relatedColOffsets: relatedColOffsets, row: make([]types.Datum, len(ctx.evalCtx.columnInfos)),