diff --git a/executor/cte.go b/executor/cte.go index 7e1cb3797265f..c01aabcbc78e0 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -108,88 +108,29 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { req.Reset() e.resTbl.Lock() + defer func() { + e.resTbl.Unlock() + }() if !e.resTbl.Done() { - // e.resTbl and e.iterInTbl is shared by different CTEExec, so only setup once. setupCTEStorageTracker(e.resTbl, e.ctx) setupCTEStorageTracker(e.iterInTbl, e.ctx) - // Compute seed part. - e.curIter = 0 - e.iterInTbl.SetIter(e.curIter) - if e.curIter >= e.ctx.GetSessionVars().CTEMaxRecursionDepth { - return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter + 1) + if err = e.computeSeedPart(ctx); err != nil { + // Don't put it in defer. + // Because it should be called only when the filling is not completed. + if err1 := e.reopenTbls(); err1 != nil { + return err1 + } + return err } - for { - chk := newFirstChunk(e.seedExec) - if err = Next(ctx, e.seedExec, chk); err != nil { - return err - } - if chk.NumRows() == 0 { - break - } - if chk, err = e.iterInTbl.Add(chk); err != nil { - return err - } - if _, err = e.resTbl.Add(chk); err != nil { - return err - } - } - - // TODO: too tricky. This means iterInTbl fill done - e.curIter++ - close(e.iterInTbl.GetBegCh()) - - if e.recursiveExec != nil && e.iterInTbl.NumChunks() != 0 { - // Start to compute recursive part. Iteration 1 begins. - for { - chk := newFirstChunk(e.recursiveExec) - if err = Next(ctx, e.recursiveExec, chk); err != nil { - return err - } - if chk.NumRows() == 0 { - e.iterInTbl.ResetData() - for i := 0; i < e.iterOutTbl.NumChunks(); i++ { - if chk, err = e.iterOutTbl.GetChunk(i); err != nil { - return err - } - if chk, err = e.resTbl.Add(chk); err != nil { - return err - } - if _, err = e.iterInTbl.Add(chk); err != nil { - return err - } - } - if err = e.iterOutTbl.ResetData(); err != nil { - return err - } - if e.iterInTbl.NumChunks() == 0 { - break - } else { - if e.curIter >= e.ctx.GetSessionVars().CTEMaxRecursionDepth { - return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter + 1) - } - // Next iteration begins. Need use iterOutTbl as input of next iteration. - e.curIter++ - e.iterInTbl.SetIter(e.curIter) - // Make sure iterInTbl is setup before Close/Open, - // because some executors will read iterInTbl in Open() (like IndexLookupJoin). - if err = e.recursiveExec.Close(); err != nil { - return err - } - if err = e.recursiveExec.Open(ctx); err != nil { - return err - } - } - } else { - if _, err = e.iterOutTbl.Add(chk); err != nil { - return err - } - } - } - } - e.resTbl.SetDone() - } - e.resTbl.Unlock() + if err = e.computeRecursivePart(ctx); err != nil { + if err1 := e.reopenTbls(); err1 != nil { + return err1 + } + return err + } + } + e.resTbl.SetDone() if e.chkIdx < e.resTbl.NumChunks() { res, err := e.resTbl.GetChunk(e.chkIdx) @@ -226,11 +167,102 @@ func (e *CTEExec) Close() (err error) { return e.baseExecutor.Close() } +func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { + e.curIter = 0 + e.iterInTbl.SetIter(e.curIter) + if e.curIter >= e.ctx.GetSessionVars().CTEMaxRecursionDepth { + return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter + 1) + } + for { + chk := newFirstChunk(e.seedExec) + if err = Next(ctx, e.seedExec, chk); err != nil { + return err + } + if chk.NumRows() == 0 { + break + } + if chk, err = e.iterInTbl.Add(chk); err != nil { + return err + } + if _, err = e.resTbl.Add(chk); err != nil { + return err + } + } + e.curIter++ + + // TODO: This means iterInTbl fill done. But too tricky. + close(e.iterInTbl.GetBegCh()) + return nil +} + +func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { + if e.recursiveExec == nil || e.iterInTbl.NumChunks() == 0 { + return nil + } + + for { + chk := newFirstChunk(e.recursiveExec) + if err = Next(ctx, e.recursiveExec, chk); err != nil { + return err + } + if chk.NumRows() == 0 { + e.iterInTbl.ResetData() + for i := 0; i < e.iterOutTbl.NumChunks(); i++ { + if chk, err = e.iterOutTbl.GetChunk(i); err != nil { + return err + } + if chk, err = e.resTbl.Add(chk); err != nil { + return err + } + if _, err = e.iterInTbl.Add(chk); err != nil { + return err + } + } + if err = e.iterOutTbl.ResetData(); err != nil { + return err + } + if e.iterInTbl.NumChunks() == 0 { + break + } else { + if e.curIter >= e.ctx.GetSessionVars().CTEMaxRecursionDepth { + return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter + 1) + } + // Next iteration begins. Need use iterOutTbl as input of next iteration. + e.curIter++ + e.iterInTbl.SetIter(e.curIter) + // Make sure iterInTbl is setup before Close/Open, + // because some executors will read iterInTbl in Open() (like IndexLookupJoin). + if err = e.recursiveExec.Close(); err != nil { + return err + } + if err = e.recursiveExec.Open(ctx); err != nil { + return err + } + } + } else { + if _, err = e.iterOutTbl.Add(chk); err != nil { + return err + } + } + } + return nil +} + func (e *CTEExec) reset() { e.curIter = 0 e.chkIdx = 0 } +func (e *CTEExec) reopenTbls() (err error) { + if err := e.resTbl.Reopen(); err != nil { + return err + } + if err := e.iterInTbl.Reopen(); err != nil { + return err + } + return nil +} + func setupCTEStorageTracker(tbl CTEStorage, ctx sessionctx.Context) { memTracker := tbl.GetMemTracker() memTracker.SetLabel(memory.LabelForCTEStorage) diff --git a/executor/cte_storage.go b/executor/cte_storage.go index e4e62c95fdc23..e274cdc7c9a61 100644 --- a/executor/cte_storage.go +++ b/executor/cte_storage.go @@ -70,6 +70,7 @@ type CTEStorage interface { ForceClose() error ResetData() error + Reopen() error NumChunks() int GetMemTracker() *memory.Tracker @@ -84,15 +85,18 @@ type CTEStorage interface { type CTEStorageRC struct { // meta info mu sync.Mutex - begCh chan struct{} refCnt int - done bool - iter int filterDup bool sc *stmtctx.StatementContext + tp []*types.FieldType + chkSize int // data info - tp []*types.FieldType + begCh chan struct{} + done bool + iter int + + // data rc *chunk.RowContainer // TODO: also track mem usage of ht ht baseHashTable @@ -109,6 +113,7 @@ func (s *CTEStorageRC) OpenAndRef(fieldType []*types.FieldType, chkSize int) (er return errors.Trace(errors.New("chunk field types are nil")) } s.tp = fieldType + s.chkSize = chkSize s.rc = chunk.NewRowContainer(fieldType, chkSize) s.refCnt = 1 s.begCh = make(chan struct{}) @@ -222,6 +227,22 @@ func (s *CTEStorageRC) ResetData() error { return s.rc.Reset() } +func (s *CTEStorageRC) Reopen() (err error) { + if s.filterDup { + s.ht = newConcurrentMapHashTable() + } + if err = s.rc.Reset(); err != nil { + return err + } + s.iter = 0 + s.begCh = make(chan struct{}) + s.done = false + // Create a new RowContainer. Because some meta infos in old RowContainer are not resetted. + // Such as memTracker/actionSpill etc. So we just use a new one. + s.rc = chunk.NewRowContainer(s.tp, s.chkSize) + return nil +} + func (s *CTEStorageRC) NumChunks() int { return s.rc.NumChunks() }