Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
Signed-off-by: wjhuang2016 <huangwenjun1997@gmail.com>
  • Loading branch information
wjhuang2016 committed Nov 23, 2021
1 parent b87f9d1 commit c953071
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
38 changes: 22 additions & 16 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/set"
Expand Down Expand Up @@ -317,6 +318,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
iter := chunk.NewIterator4Chunk(srcChk)

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.Collate)
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
Expand Down Expand Up @@ -359,7 +361,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])

Expand All @@ -386,7 +388,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
Expand All @@ -409,7 +411,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[2])
}
Expand Down Expand Up @@ -447,6 +449,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
{Expr: args[0], Desc: true},
}
}
ctor := collate.GetCollator(args[0].GetType().Collate)
partialDesc, finalDesc := desc.Split([]int{0, 1})

// build partial func for partial phase.
Expand All @@ -467,7 +470,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, p.retType)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zero(t, result)

Expand All @@ -488,7 +491,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zero(t, result)
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
Expand All @@ -499,7 +502,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
require.NoError(t, err)

dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor)
require.NoError(t, err)
require.Zero(t, result)
}
Expand Down Expand Up @@ -570,6 +573,7 @@ func testAggFunc(t *testing.T, p aggTest) {
ctx := mock.NewContext()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.Collate)
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
Expand All @@ -596,7 +600,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -606,7 +610,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])

Expand Down Expand Up @@ -639,7 +643,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -649,7 +653,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])
}
Expand All @@ -658,6 +662,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
srcChk := p.genSrcChk()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.Collate)
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
Expand Down Expand Up @@ -685,7 +690,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -695,7 +700,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[0])
}
Expand Down Expand Up @@ -749,6 +754,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
{Expr: args[0], Desc: true},
}
}
ctor := collate.GetCollator(args[0].GetType().Collate)
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
Expand All @@ -762,7 +768,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -772,7 +778,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[0])

Expand Down Expand Up @@ -805,7 +811,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -815,7 +821,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zero(t, result)
}
Expand Down
4 changes: 3 additions & 1 deletion store/mockstore/mockcopr/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
)

type aggCtxsMapper map[string][]*aggregation.AggEvaluateContext
Expand Down Expand Up @@ -208,6 +209,7 @@ type streamAggExec struct {
aggExprs []aggregation.Aggregation
aggCtxs []*aggregation.AggEvaluateContext
groupByExprs []expression.Expression
groupByCollators []collate.Collator
relatedColOffsets []int
row []types.Datum
tmpGroupByRow []types.Datum
Expand Down Expand Up @@ -288,7 +290,7 @@ func (e *streamAggExec) meetNewGroup(row [][]byte) (bool, error) {
return false, errors.Trace(err)
}
if matched {
c, err := d.CompareDatum(e.evalCtx.sc, &e.nextGroupByRow[i])
c, err := d.Compare(e.evalCtx.sc, &e.nextGroupByRow[i], e.groupByCollators[i])
if err != nil {
return false, errors.Trace(err)
}
Expand Down
5 changes: 5 additions & 0 deletions store/mockstore/mockcopr/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,17 @@ func (h coprHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*
for _, agg := range aggs {
aggCtxs = append(aggCtxs, agg.CreateContext(ctx.evalCtx.sc))
}
groupByCollators := make([]collate.Collator, 0, len(groupBys))
for _, expr := range groupBys {
groupByCollators = append(groupByCollators, collate.GetCollator(expr.GetType().Collate))
}

return &streamAggExec{
evalCtx: ctx.evalCtx,
aggExprs: aggs,
aggCtxs: aggCtxs,
groupByExprs: groupBys,
groupByCollators: groupByCollators,
currGroupByValues: make([][]byte, 0),
relatedColOffsets: relatedColOffsets,
row: make([]types.Datum, len(ctx.evalCtx.columnInfos)),
Expand Down

0 comments on commit c953071

Please sign in to comment.