From 4211cb7929926b5d93bf521274f8664557b6c26f Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Mon, 18 May 2020 14:37:44 +0800 Subject: [PATCH] executor: fix memory corrupt in COUNT/JSON_OBJECTAGG/GROUP_CONCAT (#17106) (#17194) --- executor/aggfuncs/aggfunc_test.go | 92 +++++++++++++++++------- executor/aggfuncs/func_count_distinct.go | 2 + executor/aggfuncs/func_group_concat.go | 12 ++-- tablecodec/tablecodec_test.go | 8 +-- types/datum.go | 4 +- util/codec/codec_test.go | 6 +- 6 files changed, 84 insertions(+), 40 deletions(-) diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index 72a6f9ce2132c..c12a6ae04c81a 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -72,6 +72,26 @@ type aggTest struct { orderBy bool } +func (p *aggTest) genSrcChk() *chunk.Chunk { + srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows) + for i := 0; i < p.numRows; i++ { + dt := p.dataGen(i) + srcChk.AppendDatum(0, &dt) + } + srcChk.AppendDatum(0, &types.Datum{}) + return srcChk +} + +// messUpChunk messes up the chunk for testing memory reference. +func (p *aggTest) messUpChunk(c *chunk.Chunk) { + for i := 0; i < p.numRows; i++ { + raw := c.GetRow(i).GetRaw(0) + for i := range raw { + raw[i] = 255 + } + } +} + type multiArgsAggTest struct { dataTypes []*types.FieldType retType *types.FieldType @@ -82,12 +102,32 @@ type multiArgsAggTest struct { orderBy bool } -func (s *testSuite) testMergePartialResult(c *C, p aggTest) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows) +func (p *multiArgsAggTest) genSrcChk() *chunk.Chunk { + srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows) for i := 0; i < p.numRows; i++ { - dt := p.dataGen(i) - srcChk.AppendDatum(0, &dt) + for j := 0; j < len(p.dataGens); j++ { + fdt := p.dataGens[j](i) + srcChk.AppendDatum(j, &fdt) + } + } + srcChk.AppendDatum(0, &types.Datum{}) + return srcChk +} + +// messUpChunk messes up the chunk for testing memory reference. +func (p *multiArgsAggTest) messUpChunk(c *chunk.Chunk) { + for i := 0; i < p.numRows; i++ { + for j := 0; j < len(p.dataGens); j++ { + raw := c.GetRow(i).GetRaw(j) + for i := range raw { + raw[i] = 255 + } + } } +} + +func (s *testSuite) testMergePartialResult(c *C, p aggTest) { + srcChk := p.genSrcChk() iter := chunk.NewIterator4Chunk(srcChk) args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} @@ -116,6 +156,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { for row := iter.Begin(); row != iter.End(); row = iter.Next() { partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) } + p.messUpChunk(srcChk) partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) dt := resultChk.GetRow(0).GetDatum(0, p.dataType) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) @@ -126,11 +167,14 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { c.Assert(err, IsNil) partialFunc.ResetPartialResult(partialResult) + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) iter.Begin() iter.Next() for row := iter.Next(); row != iter.End(); row = iter.Next() { partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) } + p.messUpChunk(srcChk) resultChk.Reset() partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) dt = resultChk.GetRow(0).GetDatum(0, p.dataType) @@ -168,13 +212,7 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i } func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) { - srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows) - for i := 0; i < p.numRows; i++ { - for j := 0; j < len(p.dataGens); j++ { - fdt := p.dataGens[j](i) - srcChk.AppendDatum(j, &fdt) - } - } + srcChk := p.genSrcChk() iter := chunk.NewIterator4Chunk(srcChk) args := make([]expression.Expression, len(p.dataTypes)) @@ -204,6 +242,7 @@ func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) { for row := iter.Begin(); row != iter.End(); row = iter.Next() { partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) } + p.messUpChunk(srcChk) partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) dt := resultChk.GetRow(0).GetDatum(0, p.retType) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) @@ -214,11 +253,14 @@ func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) { c.Assert(err, IsNil) partialFunc.ResetPartialResult(partialResult) + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) iter.Begin() iter.Next() for row := iter.Next(); row != iter.End(); row = iter.Next() { partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) } + p.messUpChunk(srcChk) resultChk.Reset() partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) dt = resultChk.GetRow(0).GetDatum(0, p.retType) @@ -300,12 +342,7 @@ func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum { } func (s *testSuite) testAggFunc(c *C, p aggTest) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows) - for i := 0; i < p.numRows; i++ { - dt := p.dataGen(i) - srcChk.AppendDatum(0, &dt) - } - srcChk.AppendDatum(0, &types.Datum{}) + srcChk := p.genSrcChk() args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} if p.funcName == ast.AggFuncGroupConcat { @@ -326,6 +363,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { for row := iter.Begin(); row != iter.End(); row = iter.Next() { finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) } + p.messUpChunk(srcChk) finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) @@ -353,13 +391,18 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { finalPr = finalFunc.AllocPartialResult() resultChk.Reset() + srcChk = p.genSrcChk() iter = chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) } + p.messUpChunk(srcChk) + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) } + p.messUpChunk(srcChk) finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) @@ -377,14 +420,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { } func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { - srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows) - for i := 0; i < p.numRows; i++ { - for j := 0; j < len(p.dataGens); j++ { - fdt := p.dataGens[j](i) - srcChk.AppendDatum(j, &fdt) - } - } - srcChk.AppendDatum(0, &types.Datum{}) + srcChk := p.genSrcChk() args := make([]expression.Expression, len(p.dataTypes)) for k := 0; k < len(p.dataTypes); k++ { @@ -409,6 +445,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { for row := iter.Begin(); row != iter.End(); row = iter.Next() { finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) } + p.messUpChunk(srcChk) finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) @@ -436,13 +473,18 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { finalPr = finalFunc.AllocPartialResult() resultChk.Reset() + srcChk = p.genSrcChk() iter = chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) } + p.messUpChunk(srcChk) + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) } + p.messUpChunk(srcChk) finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) diff --git a/executor/aggfuncs/func_count_distinct.go b/executor/aggfuncs/func_count_distinct.go index 867dde487fb74..50d3750877963 100644 --- a/executor/aggfuncs/func_count_distinct.go +++ b/executor/aggfuncs/func_count_distinct.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/set" + "github.com/pingcap/tidb/util/stringutil" ) type partialResult4CountDistinctInt struct { @@ -254,6 +255,7 @@ func (e *countOriginalWithDistinct4String) UpdatePartialResult(sctx sessionctx.C if p.valSet.Exist(input) { continue } + input = stringutil.Copy(input) p.valSet.Insert(input) } diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go index 932d25b901514..8493365297ee8 100644 --- a/executor/aggfuncs/func_group_concat.go +++ b/executor/aggfuncs/func_group_concat.go @@ -224,7 +224,7 @@ func (e *groupConcatDistinct) GetTruncated() *int32 { type sortRow struct { buffer *bytes.Buffer - byItems []types.Datum + byItems []*types.Datum } type topNRows struct { @@ -245,7 +245,7 @@ func (h topNRows) Len() int { func (h topNRows) Less(i, j int) bool { n := len(h.rows[i].byItems) for k := 0; k < n; k++ { - ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, &h.rows[j].byItems[k]) + ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, h.rows[j].byItems[k]) if err != nil { h.err = err return false @@ -384,14 +384,14 @@ func (e *groupConcatOrder) UpdatePartialResult(sctx sessionctx.Context, rowsInGr } sortRow := sortRow{ buffer: buffer, - byItems: make([]types.Datum, 0, len(e.byItems)), + byItems: make([]*types.Datum, 0, len(e.byItems)), } for _, byItem := range e.byItems { d, err := byItem.Expr.Eval(row) if err != nil { return err } - sortRow.byItems = append(sortRow.byItems, d) + sortRow.byItems = append(sortRow.byItems, d.Copy()) } truncated := p.topN.tryToAdd(sortRow) if p.topN.err != nil { @@ -493,14 +493,14 @@ func (e *groupConcatDistinctOrder) UpdatePartialResult(sctx sessionctx.Context, p.valSet.Insert(joinedVal) sortRow := sortRow{ buffer: buffer, - byItems: make([]types.Datum, 0, len(e.byItems)), + byItems: make([]*types.Datum, 0, len(e.byItems)), } for _, byItem := range e.byItems { d, err := byItem.Expr.Eval(row) if err != nil { return err } - sortRow.byItems = append(sortRow.byItems, d) + sortRow.byItems = append(sortRow.byItems, d.Copy()) } truncated := p.topN.tryToAdd(sortRow) if p.topN.err != nil { diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go index ad96640fcb2eb..c4296878cfdb9 100644 --- a/tablecodec/tablecodec_test.go +++ b/tablecodec/tablecodec_test.go @@ -73,8 +73,8 @@ func (s *testTableCodecSuite) TestRowCodec(c *C) { row[0] = types.NewIntDatum(100) row[1] = types.NewBytesDatum([]byte("abc")) row[2] = types.NewDecimalDatum(types.NewDecFromInt(1)) - row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 0}) - row[4] = types.NewDatum(types.Set{Name: "a", Value: 0}) + row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1}) + row[4] = types.NewDatum(types.Set{Name: "a", Value: 1}) row[5] = types.NewDatum(types.BinaryLiteral{100}) // Encode colIDs := make([]int64, 0, len(row)) @@ -101,11 +101,11 @@ func (s *testTableCodecSuite) TestRowCodec(c *C) { c.Assert(ok, IsTrue) equal, err1 := v.CompareDatum(sc, &row[i]) c.Assert(err1, IsNil) - c.Assert(equal, Equals, 0) + c.Assert(equal, Equals, 0, Commentf("expect: %v, got %v", row[i], v)) } // colMap may contains more columns than encoded row. - colMap[4] = types.NewFieldType(mysql.TypeFloat) + // colMap[4] = types.NewFieldType(mysql.TypeFloat) r, err = DecodeRow(bs, colMap, time.UTC) c.Assert(err, IsNil) c.Assert(r, NotNil) diff --git a/types/datum.go b/types/datum.go index 4fa5b4a608775..0c12e841f70ef 100644 --- a/types/datum.go +++ b/types/datum.go @@ -666,7 +666,7 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration) func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum) (int, error) { switch d.k { - case KindString, KindBytes: + case KindString, KindBytes, KindMysqlEnum, KindMysqlSet: return CompareString(d.GetString(), enum.String()), nil default: return d.compareFloat64(sc, enum.ToNumber()) @@ -691,7 +691,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set) (int, error) { switch d.k { - case KindString, KindBytes: + case KindString, KindBytes, KindMysqlEnum, KindMysqlSet: return CompareString(d.GetString(), set.String()), nil default: return d.compareFloat64(sc, set.ToNumber()) diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index 90c84869ad18e..66da69f680b0b 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -918,7 +918,7 @@ func (s *testCodecSuite) TestDecodeOneToChunk(c *C) { } else { cmp, err := got.CompareDatum(sc, &expect) c.Assert(err, IsNil) - c.Assert(cmp, Equals, 0) + c.Assert(cmp, Equals, 0, Commentf("expect: %v, got %v", expect, got)) } } } @@ -954,8 +954,8 @@ func datumsForTest(sc *stmtctx.StatementContext) ([]types.Datum, []*types.FieldT Type: mysql.TypeTimestamp, }, types.NewFieldType(mysql.TypeTimestamp)}, {types.Duration{Duration: time.Second, Fsp: 1}, types.NewFieldType(mysql.TypeDuration)}, - {types.Enum{Name: "a", Value: 0}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}}, - {types.Set{Name: "a", Value: 0}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}}, + {types.Enum{Name: "a", Value: 1}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}}, + {types.Set{Name: "a", Value: 1}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}}, {types.BinaryLiteral{100}, &types.FieldType{Tp: mysql.TypeBit, Flen: 8}}, {json.CreateBinary("abc"), types.NewFieldType(mysql.TypeJSON)}, {int64(1), types.NewFieldType(mysql.TypeYear)},