diff --git a/executor/builder.go b/executor/builder.go index 85eb1e63ce26f..81e48e71b5fa3 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -120,6 +120,23 @@ type MockPhysicalPlan interface { GetExecutor() Executor } +// MockExecutorBuilder is a wrapper for executorBuilder. +// ONLY used in test. +type MockExecutorBuilder struct { + *executorBuilder +} + +// NewMockExecutorBuilderForTest is ONLY used in test. +func NewMockExecutorBuilderForTest(ctx sessionctx.Context, is infoschema.InfoSchema, ti *TelemetryInfo, snapshotTS uint64, isStaleness bool, replicaReadScope string) *MockExecutorBuilder { + return &MockExecutorBuilder{ + executorBuilder: newExecutorBuilder(ctx, is, ti, snapshotTS, isStaleness, replicaReadScope)} +} + +// Build builds an executor tree according to `p`. +func (b *MockExecutorBuilder) Build(p plannercore.Plan) Executor { + return b.build(p) +} + func (b *executorBuilder) build(p plannercore.Plan) Executor { switch v := p.(type) { case nil: diff --git a/executor/cte.go b/executor/cte.go index 4fe3414b97154..a432fb15fe6ba 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -203,8 +203,10 @@ func (e *CTEExec) Close() (err error) { } // `iterInTbl` and `resTbl` are shared by multiple operators, // so will be closed when the SQL finishes. - if err = e.iterOutTbl.DerefAndClose(); err != nil { - return err + if e.iterOutTbl != nil { + if err = e.iterOutTbl.DerefAndClose(); err != nil { + return err + } } } diff --git a/executor/distsql.go b/executor/distsql.go index b6be537114d15..f3002d508688c 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -196,12 +196,14 @@ type IndexReaderExecutor struct { } // Close clears all resources hold by current object. -func (e *IndexReaderExecutor) Close() error { +func (e *IndexReaderExecutor) Close() (err error) { if e.table != nil && e.table.Meta().TempTableType != model.TempTableNone { return nil } - err := e.result.Close() + if e.result != nil { + err = e.result.Close() + } e.result = nil e.ctx.StoreQueryFeedback(e.feedback) return err diff --git a/executor/executor_test.go b/executor/executor_test.go index 4598bb9eb17fa..e022ca2dc6dea 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -21,6 +21,8 @@ import ( "math/rand" "net" "os" + "reflect" + "runtime" "strconv" "strings" "sync" @@ -8815,6 +8817,133 @@ func (s *testSerialSuite) TestIssue28650(c *C) { } } +// Test invoke Close without invoking Open before for each operators. +func (s *testSerialSuite) TestUnreasonablyClose(c *C) { + defer testleak.AfterTest(c)() + + is := infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable(), plannercore.MockUnsignedTable()}) + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + // To enable the shuffleExec operator. + _, err = se.Execute(context.Background(), "set @@tidb_merge_join_concurrency=4") + c.Assert(err, IsNil) + + var opsNeedsCovered = []plannercore.PhysicalPlan{ + &plannercore.PhysicalHashJoin{}, + &plannercore.PhysicalMergeJoin{}, + &plannercore.PhysicalIndexJoin{}, + &plannercore.PhysicalIndexHashJoin{}, + &plannercore.PhysicalTableReader{}, + &plannercore.PhysicalIndexReader{}, + &plannercore.PhysicalIndexLookUpReader{}, + &plannercore.PhysicalIndexMergeReader{}, + &plannercore.PhysicalApply{}, + &plannercore.PhysicalHashAgg{}, + &plannercore.PhysicalStreamAgg{}, + &plannercore.PhysicalLimit{}, + &plannercore.PhysicalSort{}, + &plannercore.PhysicalTopN{}, + &plannercore.PhysicalCTE{}, + &plannercore.PhysicalCTETable{}, + &plannercore.PhysicalMaxOneRow{}, + &plannercore.PhysicalProjection{}, + &plannercore.PhysicalSelection{}, + &plannercore.PhysicalTableDual{}, + &plannercore.PhysicalWindow{}, + &plannercore.PhysicalShuffle{}, + &plannercore.PhysicalUnionAll{}, + } + executorBuilder := executor.NewMockExecutorBuilderForTest(se, is, nil, math.MaxUint64, false, "global") + + var opsNeedsCoveredMask uint64 = 1< t1.a) AS a from t as t1) t", + "select /*+ hash_agg() */ count(f) from t group by a", + "select /*+ stream_agg() */ count(f) from t group by a", + "select * from t order by a, f", + "select * from t order by a, f limit 1", + "select * from t limit 1", + "select (select t1.a from t t1 where t1.a > t2.a) as a from t t2;", + "select a + 1 from t", + "select count(*) a from t having a > 1", + "select * from t where a = 1.1", + "with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 0) select * from cte1", + "select /*+use_index_merge(t, c_d_e, f)*/ * from t where c < 1 or f > 2", + "select sum(f) over (partition by f) from t", + "select /*+ merge_join(t1)*/ * from t t1 join t t2 on t1.d = t2.d", + "select a from t union all select a from t", + } { + comment := Commentf("case:%v sql:%s", i, tc) + c.Assert(err, IsNil, comment) + stmt, err := s.ParseOneStmt(tc, "", "") + c.Assert(err, IsNil, comment) + + err = se.NewTxn(context.Background()) + c.Assert(err, IsNil, comment) + p, _, err := planner.Optimize(context.TODO(), se, stmt, is) + c.Assert(err, IsNil, comment) + // This for loop level traverses the plan tree to get which operators are covered. + for child := []plannercore.PhysicalPlan{p.(plannercore.PhysicalPlan)}; len(child) != 0; { + newChild := make([]plannercore.PhysicalPlan, 0, len(child)) + for _, ch := range child { + found := false + for k, t := range opsNeedsCovered { + if reflect.TypeOf(t) == reflect.TypeOf(ch) { + opsAlreadyCoveredMask |= 1 << k + found = true + break + } + } + c.Assert(found, IsTrue, Commentf("case: %v sql: %s operator %v is not registered in opsNeedsCoveredMask", i, tc, reflect.TypeOf(ch))) + switch x := ch.(type) { + case *plannercore.PhysicalCTE: + newChild = append(newChild, x.RecurPlan) + newChild = append(newChild, x.SeedPlan) + continue + case *plannercore.PhysicalShuffle: + newChild = append(newChild, x.DataSources...) + newChild = append(newChild, x.Tails...) + continue + } + newChild = append(newChild, ch.Children()...) + } + child = newChild + } + + e := executorBuilder.Build(p) + + func() { + defer func() { + r := recover() + buf := make([]byte, 4096) + stackSize := runtime.Stack(buf, false) + buf = buf[:stackSize] + c.Assert(r, IsNil, Commentf("case: %v\n sql: %s\n error stack: %v", i, tc, string(buf))) + }() + c.Assert(e.Close(), IsNil, comment) + }() + } + // The following code is used to make sure all the operators registered + // in opsNeedsCoveredMask are covered. + commentBuf := strings.Builder{} + if opsAlreadyCoveredMask != opsNeedsCoveredMask { + for i := range opsNeedsCovered { + if opsAlreadyCoveredMask&(1<