Skip to content

Commit

Permalink
executor: add an unit test case for unreasonable invoking Close (#30696)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored Dec 16, 2021
1 parent bb8774b commit 8cf847a
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 6 deletions.
17 changes: 17 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,23 @@ type MockPhysicalPlan interface {
GetExecutor() Executor
}

// MockExecutorBuilder is a wrapper for executorBuilder.
// ONLY used in test.
type MockExecutorBuilder struct {
*executorBuilder
}

// NewMockExecutorBuilderForTest is ONLY used in test.
func NewMockExecutorBuilderForTest(ctx sessionctx.Context, is infoschema.InfoSchema, ti *TelemetryInfo, snapshotTS uint64, isStaleness bool, replicaReadScope string) *MockExecutorBuilder {
return &MockExecutorBuilder{
executorBuilder: newExecutorBuilder(ctx, is, ti, snapshotTS, isStaleness, replicaReadScope)}
}

// Build builds an executor tree according to `p`.
func (b *MockExecutorBuilder) Build(p plannercore.Plan) Executor {
return b.build(p)
}

func (b *executorBuilder) build(p plannercore.Plan) Executor {
switch v := p.(type) {
case nil:
Expand Down
6 changes: 4 additions & 2 deletions executor/cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ func (e *CTEExec) Close() (err error) {
}
// `iterInTbl` and `resTbl` are shared by multiple operators,
// so will be closed when the SQL finishes.
if err = e.iterOutTbl.DerefAndClose(); err != nil {
return err
if e.iterOutTbl != nil {
if err = e.iterOutTbl.DerefAndClose(); err != nil {
return err
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions executor/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,14 @@ type IndexReaderExecutor struct {
}

// Close clears all resources hold by current object.
func (e *IndexReaderExecutor) Close() error {
func (e *IndexReaderExecutor) Close() (err error) {
if e.table != nil && e.table.Meta().TempTableType != model.TempTableNone {
return nil
}

err := e.result.Close()
if e.result != nil {
err = e.result.Close()
}
e.result = nil
e.ctx.StoreQueryFeedback(e.feedback)
return err
Expand Down
130 changes: 130 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -80,6 +82,7 @@ import (
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/rowcodec"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
"github.com/pingcap/tidb/util/timeutil"
"github.com/pingcap/tipb/go-tipb"
Expand Down Expand Up @@ -9501,3 +9504,130 @@ func (s *testSerialSuite) TestIssue30289(c *C) {
err := tk.QueryToErr("select /*+ hash_join(t1) */ * from t t1 join t t2 on t1.a=t2.a")
c.Assert(err.Error(), Matches, "issue30289 build return error")
}

// Test invoke Close without invoking Open before for each operators.
func (s *testSerialSuite) TestUnreasonablyClose(c *C) {
defer testleak.AfterTest(c)()

is := infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable(), plannercore.MockUnsignedTable()})
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
_, err = se.Execute(context.Background(), "use test")
c.Assert(err, IsNil)
// To enable the shuffleExec operator.
_, err = se.Execute(context.Background(), "set @@tidb_merge_join_concurrency=4")
c.Assert(err, IsNil)

var opsNeedsCovered = []plannercore.PhysicalPlan{
&plannercore.PhysicalHashJoin{},
&plannercore.PhysicalMergeJoin{},
&plannercore.PhysicalIndexJoin{},
&plannercore.PhysicalIndexHashJoin{},
&plannercore.PhysicalTableReader{},
&plannercore.PhysicalIndexReader{},
&plannercore.PhysicalIndexLookUpReader{},
&plannercore.PhysicalIndexMergeReader{},
&plannercore.PhysicalApply{},
&plannercore.PhysicalHashAgg{},
&plannercore.PhysicalStreamAgg{},
&plannercore.PhysicalLimit{},
&plannercore.PhysicalSort{},
&plannercore.PhysicalTopN{},
&plannercore.PhysicalCTE{},
&plannercore.PhysicalCTETable{},
&plannercore.PhysicalMaxOneRow{},
&plannercore.PhysicalProjection{},
&plannercore.PhysicalSelection{},
&plannercore.PhysicalTableDual{},
&plannercore.PhysicalWindow{},
&plannercore.PhysicalShuffle{},
&plannercore.PhysicalUnionAll{},
}
executorBuilder := executor.NewMockExecutorBuilderForTest(se, is, nil, math.MaxUint64, false, "global")

var opsNeedsCoveredMask uint64 = 1<<len(opsNeedsCovered) - 1
opsAlreadyCoveredMask := uint64(0)
for i, tc := range []string{
"select /*+ hash_join(t1)*/ * from t t1 join t t2 on t1.a = t2.a",
"select /*+ merge_join(t1)*/ * from t t1 join t t2 on t1.f = t2.f",
"select t.f from t use index(f)",
"select /*+ inl_join(t1) */ * from t t1 join t t2 on t1.f=t2.f",
"select /*+ inl_hash_join(t1) */ * from t t1 join t t2 on t1.f=t2.f",
"SELECT count(1) FROM (SELECT (SELECT min(a) FROM t as t2 WHERE t2.a > t1.a) AS a from t as t1) t",
"select /*+ hash_agg() */ count(f) from t group by a",
"select /*+ stream_agg() */ count(f) from t group by a",
"select * from t order by a, f",
"select * from t order by a, f limit 1",
"select * from t limit 1",
"select (select t1.a from t t1 where t1.a > t2.a) as a from t t2;",
"select a + 1 from t",
"select count(*) a from t having a > 1",
"select * from t where a = 1.1",
"with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 0) select * from cte1",
"select /*+use_index_merge(t, c_d_e, f)*/ * from t where c < 1 or f > 2",
"select sum(f) over (partition by f) from t",
"select /*+ merge_join(t1)*/ * from t t1 join t t2 on t1.d = t2.d",
"select a from t union all select a from t",
} {
comment := Commentf("case:%v sql:%s", i, tc)
c.Assert(err, IsNil, comment)
stmt, err := s.ParseOneStmt(tc, "", "")
c.Assert(err, IsNil, comment)

err = se.NewTxn(context.Background())
c.Assert(err, IsNil, comment)
p, _, err := planner.Optimize(context.TODO(), se, stmt, is)
c.Assert(err, IsNil, comment)
// This for loop level traverses the plan tree to get which operators are covered.
for child := []plannercore.PhysicalPlan{p.(plannercore.PhysicalPlan)}; len(child) != 0; {
newChild := make([]plannercore.PhysicalPlan, 0, len(child))
for _, ch := range child {
found := false
for k, t := range opsNeedsCovered {
if reflect.TypeOf(t) == reflect.TypeOf(ch) {
opsAlreadyCoveredMask |= 1 << k
found = true
break
}
}
c.Assert(found, IsTrue, Commentf("case: %v sql: %s operator %v is not registered in opsNeedsCoveredMask", i, tc, reflect.TypeOf(ch)))
switch x := ch.(type) {
case *plannercore.PhysicalCTE:
newChild = append(newChild, x.RecurPlan)
newChild = append(newChild, x.SeedPlan)
continue
case *plannercore.PhysicalShuffle:
newChild = append(newChild, x.DataSources...)
newChild = append(newChild, x.Tails...)
continue
}
newChild = append(newChild, ch.Children()...)
}
child = newChild
}

e := executorBuilder.Build(p)

func() {
defer func() {
r := recover()
buf := make([]byte, 4096)
stackSize := runtime.Stack(buf, false)
buf = buf[:stackSize]
c.Assert(r, IsNil, Commentf("case: %v\n sql: %s\n error stack: %v", i, tc, string(buf)))
}()
c.Assert(e.Close(), IsNil, comment)
}()
}
// The following code is used to make sure all the operators registered
// in opsNeedsCoveredMask are covered.
commentBuf := strings.Builder{}
if opsAlreadyCoveredMask != opsNeedsCoveredMask {
for i := range opsNeedsCovered {
if opsAlreadyCoveredMask&(1<<i) != 1<<i {
commentBuf.WriteString(fmt.Sprintf(" %v", reflect.TypeOf(opsNeedsCovered[i])))
}
}
}
c.Assert(opsAlreadyCoveredMask, Equals, opsNeedsCoveredMask, Commentf("these operators are not covered %s", commentBuf.String()))
}
9 changes: 7 additions & 2 deletions executor/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ type hashjoinWorkerResult struct {

// Close implements the Executor Close interface.
func (e *HashJoinExec) Close() error {
close(e.closeCh)
if e.closeCh != nil {
close(e.closeCh)
}
e.finished.Store(true)
if e.prepared {
if e.buildFinished != nil {
Expand Down Expand Up @@ -156,6 +158,9 @@ func (e *HashJoinExec) Close() error {

// Open implements the Executor Open interface.
func (e *HashJoinExec) Open(ctx context.Context) error {
if err := e.baseExecutor.Open(ctx); err != nil {
return err
}
e.prepared = false
e.memTracker = memory.NewTracker(e.id, -1)
e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
Expand All @@ -179,7 +184,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
}
e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats)
}
return e.baseExecutor.Open(ctx)
return nil
}

// fetchProbeSideChunks get chunks from fetches chunks from the big table in a background goroutine
Expand Down
5 changes: 5 additions & 0 deletions executor/merge_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type MergeJoinExec struct {
}

type mergeJoinTable struct {
inited bool
isInner bool
childIndex int
joinKeys []*expression.Column
Expand Down Expand Up @@ -108,10 +109,14 @@ func (t *mergeJoinTable) init(exec *MergeJoinExec) {
}

t.memTracker.AttachTo(exec.memTracker)
t.inited = true
t.memTracker.Consume(t.childChunk.MemoryUsage())
}

func (t *mergeJoinTable) finish() error {
if !t.inited {
return nil
}
t.memTracker.Consume(-t.childChunk.MemoryUsage())

if t.isInner {
Expand Down

0 comments on commit 8cf847a

Please sign in to comment.