Skip to content

Commit

Permalink
executor: fix memory corruption in COUNT/JSON_OBJECTAGG/GROUP_CONCAT (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
SunRunAway authored May 18, 2020
1 parent f718bed commit 4211cb7
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 40 deletions.
92 changes: 67 additions & 25 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ type aggTest struct {
orderBy bool
}

// genSrcChk builds a fresh source chunk for the test: p.numRows datums
// produced by p.dataGen, followed by one trailing null datum in column 0.
func (p *aggTest) genSrcChk() *chunk.Chunk {
	chk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
	for row := 0; row < p.numRows; row++ {
		d := p.dataGen(row)
		chk.AppendDatum(0, &d)
	}
	// Append a trailing NULL so aggregate functions are exercised on null input.
	var null types.Datum
	chk.AppendDatum(0, &null)
	return chk
}

// messUpChunk messes up the chunk for testing memory reference: it
// overwrites every value byte of column 0 in the first p.numRows rows
// with 0xFF, so an aggregate function that illegally retains a reference
// into the chunk's memory produces a visibly corrupted result.
func (p *aggTest) messUpChunk(c *chunk.Chunk) {
	for i := 0; i < p.numRows; i++ {
		raw := c.GetRow(i).GetRaw(0)
		// Use a distinct index name; the original shadowed the outer loop
		// variable `i`, which is legal but confusing and linter-flagged.
		for j := range raw {
			raw[j] = 255
		}
	}
}

type multiArgsAggTest struct {
dataTypes []*types.FieldType
retType *types.FieldType
Expand All @@ -82,12 +102,32 @@ type multiArgsAggTest struct {
orderBy bool
}

func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
// genSrcChk builds a fresh multi-column source chunk for the test:
// p.numRows rows with one datum per generator in p.dataGens, followed by
// one trailing null datum in column 0.
func (p *multiArgsAggTest) genSrcChk() *chunk.Chunk {
	chk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
	for row := 0; row < p.numRows; row++ {
		for col := range p.dataGens {
			d := p.dataGens[col](row)
			chk.AppendDatum(col, &d)
		}
	}
	// Append a trailing NULL so aggregate functions are exercised on null input.
	var null types.Datum
	chk.AppendDatum(0, &null)
	return chk
}

// messUpChunk messes up the chunk for testing memory reference: it
// overwrites every value byte of every argument column in the first
// p.numRows rows with 0xFF, so an aggregate function that illegally
// retains a reference into the chunk's memory produces a visibly
// corrupted result.
func (p *multiArgsAggTest) messUpChunk(c *chunk.Chunk) {
	for i := 0; i < p.numRows; i++ {
		for j := 0; j < len(p.dataGens); j++ {
			raw := c.GetRow(i).GetRaw(j)
			// Use a distinct index name; the original shadowed the outer
			// loop variable `i`, which is legal but confusing.
			for k := range raw {
				raw[k] = 255
			}
		}
	}
}

func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
srcChk := p.genSrcChk()
iter := chunk.NewIterator4Chunk(srcChk)

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
Expand Down Expand Up @@ -116,6 +156,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, p.dataType)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
Expand All @@ -126,11 +167,14 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
c.Assert(err, IsNil)
partialFunc.ResetPartialResult(partialResult)

srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
resultChk.Reset()
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
Expand Down Expand Up @@ -168,13 +212,7 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i
}

func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk := p.genSrcChk()
iter := chunk.NewIterator4Chunk(srcChk)

args := make([]expression.Expression, len(p.dataTypes))
Expand Down Expand Up @@ -204,6 +242,7 @@ func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, p.retType)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
Expand All @@ -214,11 +253,14 @@ func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
c.Assert(err, IsNil)
partialFunc.ResetPartialResult(partialResult)

srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
resultChk.Reset()
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
Expand Down Expand Up @@ -300,12 +342,7 @@ func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum {
}

func (s *testSuite) testAggFunc(c *C, p aggTest) {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
for i := 0; i < p.numRows; i++ {
dt := p.dataGen(i)
srcChk.AppendDatum(0, &dt)
}
srcChk.AppendDatum(0, &types.Datum{})
srcChk := p.genSrcChk()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
if p.funcName == ast.AggFuncGroupConcat {
Expand All @@ -326,6 +363,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand Down Expand Up @@ -353,13 +391,18 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
finalPr = finalFunc.AllocPartialResult()

resultChk.Reset()
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand All @@ -377,14 +420,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
}

func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk.AppendDatum(0, &types.Datum{})
srcChk := p.genSrcChk()

args := make([]expression.Expression, len(p.dataTypes))
for k := 0; k < len(p.dataTypes); k++ {
Expand All @@ -409,6 +445,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand Down Expand Up @@ -436,13 +473,18 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
finalPr = finalFunc.AllocPartialResult()

resultChk.Reset()
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand Down
2 changes: 2 additions & 0 deletions executor/aggfuncs/func_count_distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/set"
"github.com/pingcap/tidb/util/stringutil"
)

type partialResult4CountDistinctInt struct {
Expand Down Expand Up @@ -254,6 +255,7 @@ func (e *countOriginalWithDistinct4String) UpdatePartialResult(sctx sessionctx.C
if p.valSet.Exist(input) {
continue
}
input = stringutil.Copy(input)
p.valSet.Insert(input)
}

Expand Down
12 changes: 6 additions & 6 deletions executor/aggfuncs/func_group_concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (e *groupConcatDistinct) GetTruncated() *int32 {

type sortRow struct {
buffer *bytes.Buffer
byItems []types.Datum
byItems []*types.Datum
}

type topNRows struct {
Expand All @@ -245,7 +245,7 @@ func (h topNRows) Len() int {
func (h topNRows) Less(i, j int) bool {
n := len(h.rows[i].byItems)
for k := 0; k < n; k++ {
ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, &h.rows[j].byItems[k])
ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, h.rows[j].byItems[k])
if err != nil {
h.err = err
return false
Expand Down Expand Up @@ -384,14 +384,14 @@ func (e *groupConcatOrder) UpdatePartialResult(sctx sessionctx.Context, rowsInGr
}
sortRow := sortRow{
buffer: buffer,
byItems: make([]types.Datum, 0, len(e.byItems)),
byItems: make([]*types.Datum, 0, len(e.byItems)),
}
for _, byItem := range e.byItems {
d, err := byItem.Expr.Eval(row)
if err != nil {
return err
}
sortRow.byItems = append(sortRow.byItems, d)
sortRow.byItems = append(sortRow.byItems, d.Copy())
}
truncated := p.topN.tryToAdd(sortRow)
if p.topN.err != nil {
Expand Down Expand Up @@ -493,14 +493,14 @@ func (e *groupConcatDistinctOrder) UpdatePartialResult(sctx sessionctx.Context,
p.valSet.Insert(joinedVal)
sortRow := sortRow{
buffer: buffer,
byItems: make([]types.Datum, 0, len(e.byItems)),
byItems: make([]*types.Datum, 0, len(e.byItems)),
}
for _, byItem := range e.byItems {
d, err := byItem.Expr.Eval(row)
if err != nil {
return err
}
sortRow.byItems = append(sortRow.byItems, d)
sortRow.byItems = append(sortRow.byItems, d.Copy())
}
truncated := p.topN.tryToAdd(sortRow)
if p.topN.err != nil {
Expand Down
8 changes: 4 additions & 4 deletions tablecodec/tablecodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func (s *testTableCodecSuite) TestRowCodec(c *C) {
row[0] = types.NewIntDatum(100)
row[1] = types.NewBytesDatum([]byte("abc"))
row[2] = types.NewDecimalDatum(types.NewDecFromInt(1))
row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 0})
row[4] = types.NewDatum(types.Set{Name: "a", Value: 0})
row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1})
row[4] = types.NewDatum(types.Set{Name: "a", Value: 1})
row[5] = types.NewDatum(types.BinaryLiteral{100})
// Encode
colIDs := make([]int64, 0, len(row))
Expand All @@ -101,11 +101,11 @@ func (s *testTableCodecSuite) TestRowCodec(c *C) {
c.Assert(ok, IsTrue)
equal, err1 := v.CompareDatum(sc, &row[i])
c.Assert(err1, IsNil)
c.Assert(equal, Equals, 0)
c.Assert(equal, Equals, 0, Commentf("expect: %v, got %v", row[i], v))
}

// colMap may contains more columns than encoded row.
colMap[4] = types.NewFieldType(mysql.TypeFloat)
// colMap[4] = types.NewFieldType(mysql.TypeFloat)
r, err = DecodeRow(bs, colMap, time.UTC)
c.Assert(err, IsNil)
c.Assert(r, NotNil)
Expand Down
4 changes: 2 additions & 2 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration)

func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum) (int, error) {
switch d.k {
case KindString, KindBytes:
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return CompareString(d.GetString(), enum.String()), nil
default:
return d.compareFloat64(sc, enum.ToNumber())
Expand All @@ -691,7 +691,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter

func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set) (int, error) {
switch d.k {
case KindString, KindBytes:
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return CompareString(d.GetString(), set.String()), nil
default:
return d.compareFloat64(sc, set.ToNumber())
Expand Down
6 changes: 3 additions & 3 deletions util/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ func (s *testCodecSuite) TestDecodeOneToChunk(c *C) {
} else {
cmp, err := got.CompareDatum(sc, &expect)
c.Assert(err, IsNil)
c.Assert(cmp, Equals, 0)
c.Assert(cmp, Equals, 0, Commentf("expect: %v, got %v", expect, got))
}
}
}
Expand Down Expand Up @@ -954,8 +954,8 @@ func datumsForTest(sc *stmtctx.StatementContext) ([]types.Datum, []*types.FieldT
Type: mysql.TypeTimestamp,
}, types.NewFieldType(mysql.TypeTimestamp)},
{types.Duration{Duration: time.Second, Fsp: 1}, types.NewFieldType(mysql.TypeDuration)},
{types.Enum{Name: "a", Value: 0}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}},
{types.Set{Name: "a", Value: 0}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}},
{types.Enum{Name: "a", Value: 1}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}},
{types.Set{Name: "a", Value: 1}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}},
{types.BinaryLiteral{100}, &types.FieldType{Tp: mysql.TypeBit, Flen: 8}},
{json.CreateBinary("abc"), types.NewFieldType(mysql.TypeJSON)},
{int64(1), types.NewFieldType(mysql.TypeYear)},
Expand Down

0 comments on commit 4211cb7

Please sign in to comment.