Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix memory corrupt in COUNT/JSON_OBJECTAGG/GROUP_CONCAT #17106

Merged
merged 8 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 67 additions & 25 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ type aggTest struct {
orderBy bool
}

// genSrcChk builds a fresh single-column source chunk for this test case.
// It appends one datum produced by p.dataGen for every row index in
// [0, p.numRows), followed by one trailing zero-value Datum.
func (p *aggTest) genSrcChk() *chunk.Chunk {
	fieldTypes := []*types.FieldType{p.dataType}
	chk := chunk.NewChunkWithCapacity(fieldTypes, p.numRows)
	for rowIdx := 0; rowIdx < p.numRows; rowIdx++ {
		datum := p.dataGen(rowIdx)
		chk.AppendDatum(0, &datum)
	}
	// One extra, empty Datum is appended after the generated rows.
	var empty types.Datum
	chk.AppendDatum(0, &empty)
	return chk
}

// messUpChunk overwrites every byte returned by Column(0).GetRaw for the
// first p.numRows rows of c with 255. Tests call it after feeding the chunk
// to an aggregate function, to check for lingering memory references.
// NOTE(review): presumably GetRaw returns a slice aliasing the column's
// storage, so the writes corrupt the chunk in place — confirm.
func (p *aggTest) messUpChunk(c *chunk.Chunk) {
	const filler = 255
	for rowIdx := 0; rowIdx < p.numRows; rowIdx++ {
		rawBytes := c.Column(0).GetRaw(rowIdx)
		for pos := range rawBytes {
			rawBytes[pos] = filler
		}
	}
}

type multiArgsAggTest struct {
dataTypes []*types.FieldType
retType *types.FieldType
Expand All @@ -82,12 +102,32 @@ type multiArgsAggTest struct {
orderBy bool
}

func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
func (p *multiArgsAggTest) genSrcChk() *chunk.Chunk {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
dt := p.dataGen(i)
srcChk.AppendDatum(0, &dt)
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk.AppendDatum(0, &types.Datum{})
return srcChk
}

// messUpChunk overwrites every byte returned by GetRaw with 255, for each of
// the first p.numRows rows in each column that has a data generator. Tests
// call it after feeding the chunk to an aggregate function, to check for
// lingering memory references.
// NOTE(review): presumably GetRaw returns a slice aliasing the column's
// storage, so the writes corrupt the chunk in place — confirm.
func (p *multiArgsAggTest) messUpChunk(c *chunk.Chunk) {
	const filler = 255
	for rowIdx := 0; rowIdx < p.numRows; rowIdx++ {
		for colIdx := 0; colIdx < len(p.dataGens); colIdx++ {
			rawBytes := c.Column(colIdx).GetRaw(rowIdx)
			for pos := range rawBytes {
				rawBytes[pos] = filler
			}
		}
	}
}

func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
srcChk := p.genSrcChk()
iter := chunk.NewIterator4Chunk(srcChk)

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
Expand Down Expand Up @@ -116,6 +156,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, p.dataType)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
Expand All @@ -126,11 +167,14 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
c.Assert(err, IsNil)
partialFunc.ResetPartialResult(partialResult)

srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
resultChk.Reset()
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
Expand Down Expand Up @@ -168,13 +212,7 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i
}

func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk := p.genSrcChk()
iter := chunk.NewIterator4Chunk(srcChk)

args := make([]expression.Expression, len(p.dataTypes))
Expand Down Expand Up @@ -204,6 +242,7 @@ func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, p.retType)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
Expand All @@ -214,11 +253,14 @@ func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
c.Assert(err, IsNil)
partialFunc.ResetPartialResult(partialResult)

srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
resultChk.Reset()
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
Expand Down Expand Up @@ -300,12 +342,7 @@ func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum {
}

func (s *testSuite) testAggFunc(c *C, p aggTest) {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
for i := 0; i < p.numRows; i++ {
dt := p.dataGen(i)
srcChk.AppendDatum(0, &dt)
}
srcChk.AppendDatum(0, &types.Datum{})
srcChk := p.genSrcChk()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
if p.funcName == ast.AggFuncGroupConcat {
Expand All @@ -326,6 +363,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand Down Expand Up @@ -353,13 +391,18 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
finalPr = finalFunc.AllocPartialResult()

resultChk.Reset()
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand All @@ -377,14 +420,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
}

func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk.AppendDatum(0, &types.Datum{})
srcChk := p.genSrcChk()

args := make([]expression.Expression, len(p.dataTypes))
for k := 0; k < len(p.dataTypes); k++ {
Expand All @@ -409,6 +445,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand Down Expand Up @@ -436,13 +473,18 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
finalPr = finalFunc.AllocPartialResult()

resultChk.Reset()
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
Expand Down
2 changes: 2 additions & 0 deletions executor/aggfuncs/func_count_distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/set"
"github.com/pingcap/tidb/util/stringutil"
)

type partialResult4CountDistinctInt struct {
Expand Down Expand Up @@ -254,6 +255,7 @@ func (e *countOriginalWithDistinct4String) UpdatePartialResult(sctx sessionctx.C
if p.valSet.Exist(input) {
continue
}
input = stringutil.Copy(input)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about modifying the stringmap?

Copy link
Member

@zz-jason zz-jason May 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[jianzhang.zj:~/Code/tidb] git:(master ✔)
➜ find . -name "*.go" | xargs grep "set.StringSet" | grep -v "test" | grep -v "(" | grep -v "{" | grep "set.StringSet"
./planner/core/planbuilder.go:  underlyingViewNames set.StringSet
./planner/core/memtable_predicate_extractor.go: result set.StringSet,
./planner/core/memtable_predicate_extractor.go: excludeCols set.StringSet,
./planner/core/memtable_predicate_extractor.go:         var values set.StringSet
./planner/core/memtable_predicate_extractor.go: NodeTypes set.StringSet
./planner/core/memtable_predicate_extractor.go: Instances set.StringSet
./planner/core/memtable_predicate_extractor.go: NodeTypes set.StringSet
./planner/core/memtable_predicate_extractor.go: Instances set.StringSet
./planner/core/memtable_predicate_extractor.go: LogLevels set.StringSet
./planner/core/memtable_predicate_extractor.go: LabelConditions map[string]set.StringSet
./planner/core/memtable_predicate_extractor.go: MetricsNames set.StringSet
./planner/core/memtable_predicate_extractor.go: Rules set.StringSet
./planner/core/memtable_predicate_extractor.go: Items set.StringSet
./planner/core/memtable_predicate_extractor.go: Rules       set.StringSet
./planner/core/memtable_predicate_extractor.go: MetricNames set.StringSet
./planner/core/memtable_predicate_extractor.go: Types       set.StringSet
./planner/core/util.go:// extractStringFromStringSet helps extract string info from set.StringSet
./executor/aggregate.go:        groupSet            set.StringSet
./executor/aggregate.go:        groupSet         set.StringSet
./executor/aggfuncs/func_count_distinct.go:     valSet set.StringSet
./executor/aggfuncs/func_count_distinct.go:     valSet set.StringSet
./executor/aggfuncs/func_count_distinct.go:     valSet set.StringSet
./executor/aggfuncs/func_group_concat.go:       valSet            set.StringSet
./executor/aggfuncs/func_group_concat.go:       valSet            set.StringSet
./executor/aggfuncs/func_avg.go:        valSet set.StringSet
./executor/aggfuncs/func_sum.go:        valSet set.StringSet
./executor/inspection_result.go:                set       set.StringSet
./expression/builtin_other.go:  hashSet set.StringSet
./expression/builtin_other.go:  hashSet set.StringSet

There are lots of places using StringSet, would they all have the same potential issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about modifying the stringmap?

I'd rather not. The users of stringmap should have control of how they handle the memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are lots of places using StringSet, would they all have the same potential issue?

This PR/issue aims to resolve the problem in package executor/aggfuncs. I'll take another look, create issues if there are further problems, and reply here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At a glance, the other usages have no problem, and I've added some comments.

p.valSet.Insert(input)
}

Expand Down
12 changes: 6 additions & 6 deletions executor/aggfuncs/func_group_concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (e *groupConcatDistinct) GetTruncated() *int32 {

type sortRow struct {
buffer *bytes.Buffer
byItems []types.Datum
byItems []*types.Datum
}

type topNRows struct {
Expand All @@ -245,7 +245,7 @@ func (h topNRows) Len() int {
func (h topNRows) Less(i, j int) bool {
n := len(h.rows[i].byItems)
for k := 0; k < n; k++ {
ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, &h.rows[j].byItems[k])
ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, h.rows[j].byItems[k])
if err != nil {
h.err = err
return false
Expand Down Expand Up @@ -384,14 +384,14 @@ func (e *groupConcatOrder) UpdatePartialResult(sctx sessionctx.Context, rowsInGr
}
sortRow := sortRow{
buffer: buffer,
byItems: make([]types.Datum, 0, len(e.byItems)),
byItems: make([]*types.Datum, 0, len(e.byItems)),
}
for _, byItem := range e.byItems {
d, err := byItem.Expr.Eval(row)
if err != nil {
return err
}
sortRow.byItems = append(sortRow.byItems, d)
sortRow.byItems = append(sortRow.byItems, d.Clone())
}
truncated := p.topN.tryToAdd(sortRow)
if p.topN.err != nil {
Expand Down Expand Up @@ -493,14 +493,14 @@ func (e *groupConcatDistinctOrder) UpdatePartialResult(sctx sessionctx.Context,
p.valSet.Insert(joinedVal)
sortRow := sortRow{
buffer: buffer,
byItems: make([]types.Datum, 0, len(e.byItems)),
byItems: make([]*types.Datum, 0, len(e.byItems)),
}
for _, byItem := range e.byItems {
d, err := byItem.Expr.Eval(row)
if err != nil {
return err
}
sortRow.byItems = append(sortRow.byItems, d)
sortRow.byItems = append(sortRow.byItems, d.Clone())
}
truncated := p.topN.tryToAdd(sortRow)
if p.topN.err != nil {
Expand Down
4 changes: 3 additions & 1 deletion executor/aggfuncs/func_json_objectagg.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/stringutil"
)

type jsonObjectAgg struct {
Expand Down Expand Up @@ -91,8 +92,9 @@ func (e *jsonObjectAgg) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup
if err != nil {
return errors.Trace(err)
}
keyString = stringutil.Copy(keyString)

realVal := value.GetValue()
realVal := value.Clone().GetValue()
switch x := realVal.(type) {
case nil, bool, int64, uint64, float64, string, json.BinaryJSON, *types.MyDecimal, []uint8, types.Time, types.Duration:
p.entries[keyString] = realVal
Expand Down
2 changes: 1 addition & 1 deletion executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
}

for j := 0; j < e.childResult.NumRows(); j++ {
groupKey := string(e.groupKeyBuffer[j])
groupKey := string(e.groupKeyBuffer[j]) // do memory copy here, because e.groupKeyBuffer may be reused.
if !e.groupSet.Exist(groupKey) {
e.groupSet.Insert(groupKey)
e.groupKeys = append(e.groupKeys, groupKey)
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er
b.hasNull = true
continue
}
b.hashSet.Insert(string(collator.Key(val)))
b.hashSet.Insert(string(collator.Key(val))) // should do memory copy here
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/memtable_predicate_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (helper extractHelper) extractCol(
continue
}
var colName string
var datums []types.Datum
var datums []types.Datum // the memory of datums should not be reused, they will be put into result.
switch fn.FuncName.L {
case ast.EQ:
colName, datums = helper.extractColBinaryOpConsExpr(extractCols, fn)
Expand Down
6 changes: 3 additions & 3 deletions tablecodec/tablecodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ func (s *testTableCodecSuite) TestRowCodec(c *C) {
row[0] = types.NewIntDatum(100)
row[1] = types.NewBytesDatum([]byte("abc"))
row[2] = types.NewDecimalDatum(types.NewDecFromInt(1))
row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 0})
row[4] = types.NewDatum(types.Set{Name: "a", Value: 0})
row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1})
row[4] = types.NewDatum(types.Set{Name: "a", Value: 1})
row[5] = types.NewDatum(types.BinaryLiteral{100})
// Encode
colIDs := make([]int64, 0, len(row))
Expand Down Expand Up @@ -104,7 +104,7 @@ func (s *testTableCodecSuite) TestRowCodec(c *C) {
c.Assert(ok, IsTrue)
equal, err1 := v.CompareDatum(sc, &row[i])
c.Assert(err1, IsNil)
c.Assert(equal, Equals, 0)
c.Assert(equal, Equals, 0, Commentf("expect: %v, got %v", row[i], v))
}

// colMap may contains more columns than encoded row.
Expand Down
4 changes: 2 additions & 2 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration)

func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum) (int, error) {
switch d.k {
case KindString, KindBytes:
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return CompareString(d.GetString(), enum.String(), d.collation), nil
default:
return d.compareFloat64(sc, enum.ToNumber())
Expand All @@ -747,7 +747,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter

func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set) (int, error) {
switch d.k {
case KindString, KindBytes:
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return CompareString(d.GetString(), set.String(), d.collation), nil
default:
return d.compareFloat64(sc, set.ToNumber())
Expand Down
8 changes: 4 additions & 4 deletions util/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,9 +959,9 @@ func (s *testCodecSuite) TestDecodeOneToChunk(c *C) {
if got.Kind() != types.KindMysqlDecimal {
cmp, err := got.CompareDatum(sc, &expect)
c.Assert(err, IsNil)
c.Assert(cmp, Equals, 0)
c.Assert(cmp, Equals, 0, Commentf("expect: %v, got %v", expect, got))
} else {
c.Assert(got.GetString(), Equals, expect.GetString())
c.Assert(got.GetString(), Equals, expect.GetString(), Commentf("expect: %v, got %v", expect, got))
}
}
}
Expand Down Expand Up @@ -1032,8 +1032,8 @@ func datumsForTest(sc *stmtctx.StatementContext) ([]types.Datum, []*types.FieldT
{types.CurrentTime(mysql.TypeDate), types.NewFieldType(mysql.TypeDate)},
{types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, types.DefaultFsp), types.NewFieldType(mysql.TypeTimestamp)},
{types.Duration{Duration: time.Second, Fsp: 1}, types.NewFieldType(mysql.TypeDuration)},
{types.Enum{Name: "a", Value: 0}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}},
{types.Set{Name: "a", Value: 0}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}},
{types.Enum{Name: "a", Value: 1}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}},
{types.Set{Name: "a", Value: 1}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}},
{types.BinaryLiteral{100}, &types.FieldType{Tp: mysql.TypeBit, Flen: 8}},
{json.CreateBinary("abc"), types.NewFieldType(mysql.TypeJSON)},
{int64(1), types.NewFieldType(mysql.TypeYear)},
Expand Down