Skip to content

Commit

Permalink
support limit in recursive cte
Browse files Browse the repository at this point in the history
  • Loading branch information
iamlinjunhong committed Aug 28, 2023
1 parent 31b78e6 commit 72a0cb2
Show file tree
Hide file tree
Showing 11 changed files with 775 additions and 671 deletions.
1,226 changes: 631 additions & 595 deletions pkg/pb/plan/plan.pb.go

Large diffs are not rendered by default.

22 changes: 21 additions & 1 deletion pkg/sql/colexec/dispatch/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ package dispatch
import (
"bytes"
"context"
"fmt"
"github.com/google/uuid"
"github.com/matrixorigin/matrixone/pkg/common/moerr"
"github.com/matrixorigin/matrixone/pkg/container/batch"
"github.com/matrixorigin/matrixone/pkg/container/types"
"github.com/matrixorigin/matrixone/pkg/container/vector"
"github.com/matrixorigin/matrixone/pkg/logutil"
"github.com/matrixorigin/matrixone/pkg/sql/colexec"
"github.com/matrixorigin/matrixone/pkg/vm/process"
Expand Down Expand Up @@ -92,7 +95,12 @@ func Prepare(proc *process.Process, arg any) error {
func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (process.ExecStatus, error) {
ap := arg.(*Argument)
bat := proc.InputBatch()
if bat == nil {
if ap.RecSink {
fmt.Println("hello world")
}
if bat == nil && ap.RecSink {
bat = makeEndBatch(proc)
} else if bat == nil {
return process.ExecStop, nil
}
if bat.Last() {
Expand All @@ -116,6 +124,18 @@ func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (p
}
}

func makeEndBatch(proc *process.Process) *batch.Batch {
b := batch.NewWithSize(1)
b.Attrs = []string{
"recursive_col",
}
b.SetVector(0, vector.NewVec(types.T_varchar.ToType()))
vector.AppendBytes(b.GetVector(0), []byte("check recursive status"), false, proc.GetMPool())
batch.SetLength(b, 1)
b.SetEnd()
return b
}

func (arg *Argument) waitRemoteRegsReady(proc *process.Process) (bool, error) {
cnt := len(arg.RemoteRegs)
for cnt > 0 {
Expand Down
51 changes: 42 additions & 9 deletions pkg/sql/colexec/mergecte/mergecte.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package mergecte
import (
"bytes"
"github.com/matrixorigin/matrixone/pkg/container/batch"
"github.com/matrixorigin/matrixone/pkg/container/types"
"github.com/matrixorigin/matrixone/pkg/container/vector"
"github.com/matrixorigin/matrixone/pkg/vm/process"
)

Expand All @@ -28,8 +30,9 @@ func Prepare(proc *process.Process, arg any) error {
ap := arg.(*Argument)
ap.ctr = new(container)
ap.ctr.InitReceiver(proc, true)
ap.ctr.nodeCnt = int32(len(proc.Reg.MergeReceivers))
ap.ctr.nodeCnt = int32(len(proc.Reg.MergeReceivers)) - 1
ap.ctr.curNodeCnt = ap.ctr.nodeCnt
ap.ctr.status = sendInitial
return nil
}

Expand All @@ -41,15 +44,35 @@ func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (p
ctr := ap.ctr
var sb *batch.Batch
var end bool
var err error

for {
sb, end, _ = ctr.ReceiveFromAllRegs(anal)
if end {
proc.SetInputBatch(nil)
return process.ExecStop, nil
switch ctr.status {
case sendInitial:
sb, _, err = ctr.ReceiveFromSingleReg(0, anal)
if err != nil {
return process.ExecStop, err
}
if sb == nil {
ctr.status = sendLastTag
}
fallthrough
case sendLastTag:
if ctr.status == sendLastTag {
ctr.status = sendRecursive
sb = makeRecursiveBatch(proc)
ctr.RemoveChosen(1)
}
case sendRecursive:
for {
sb, end, _ = ctr.ReceiveFromAllRegs(anal)
if sb == nil || end {
proc.SetInputBatch(nil)
return process.ExecStop, nil
}
if !sb.Last() {
break
}

if sb.Last() {
sb.SetLast()
ap.ctr.curNodeCnt--
if ap.ctr.curNodeCnt == 0 {
Expand All @@ -58,8 +81,6 @@ func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (p
} else {
proc.PutBatch(sb)
}
} else {
break
}
}

Expand All @@ -68,3 +89,15 @@ func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (p
proc.SetInputBatch(sb)
return process.ExecNext, nil
}

func makeRecursiveBatch(proc *process.Process) *batch.Batch {
b := batch.NewWithSize(1)
b.Attrs = []string{
"recursive_col",
}
b.SetVector(0, vector.NewVec(types.T_varchar.ToType()))
vector.AppendBytes(b.GetVector(0), []byte("check recursive status"), false, proc.GetMPool())
batch.SetLength(b, 1)
b.SetLast()
return b
}
7 changes: 7 additions & 0 deletions pkg/sql/colexec/mergecte/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,17 @@ import (
"github.com/matrixorigin/matrixone/pkg/vm/process"
)

const (
sendInitial = 0
sendLastTag = 1
sendRecursive = 2
)

type container struct {
colexec.ReceiverOperator
nodeCnt int32
curNodeCnt int32
status int32
}

type Argument struct {
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/colexec/mergelimit/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (p
proc.SetInputBatch(nil)
return process.ExecStop, nil
}
if bat.Last() {
proc.SetInputBatch(bat)
return process.ExecNext, nil
}

anal.Input(bat, isFirst)
if ap.ctr.seen >= ap.Limit {
Expand Down
115 changes: 58 additions & 57 deletions pkg/sql/colexec/mergerecursive/mergerecursive.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,70 +41,71 @@ func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (p
ctr := ap.ctr
var sb *batch.Batch

bat, end, err := ctr.ReceiveFromSingleRegNonBlock(1, anal)
if err != nil {
return process.ExecStop, err
}
if end {
proc.SetInputBatch(nil)
return process.ExecStop, nil
}
if bat != nil {
//b, err := bat.Dup(proc.GetMPool())
//if err != nil {
// return process.ExecStop, err
//}
//b.Recursive = bat.Recursive
ctr.bats = append(ctr.bats, bat)
if bat.Last() {
ctr.last = true
}
}
//bat, end, err := ctr.ReceiveFromSingleRegNonBlock(1, anal)
//if err != nil {
// return process.ExecStop, err
//}
//if end {
// proc.SetInputBatch(nil)
// return process.ExecStop, nil
//}
//if bat != nil {
// ctr.bats = append(ctr.bats, bat)
// if bat.Last() {
// ctr.last = true
// }
//}
//
//switch ctr.status {
//case SendA:
// sb, _, err = ctr.ReceiveFromSingleReg(0, anal)
// if err != nil {
// return process.ExecStop, err
// }
// if sb == nil {
// ctr.status = SendLastBatch
// }
// fallthrough
//case SendLastBatch:
// if ctr.status == SendLastBatch {
// ctr.status = SendB
// sb = makeRecursiveBatch(proc)
// }
//case SendB:
// for !ctr.last {
// bat, _, err = ctr.ReceiveFromSingleReg(1, anal)
// if err != nil {
// return process.ExecStop, err
// }
// if bat == nil || bat.End() {
// proc.SetInputBatch(nil)
// return process.ExecStop, nil
// }
// if bat.Last() {
// ctr.last = true
// }
// ctr.bats = append(ctr.bats, bat)
// }
// sb = ctr.bats[0]
// ctr.bats = ctr.bats[1:]
//}

switch ctr.status {
case SendA:
sb, _, err = ctr.ReceiveFromSingleReg(0, anal)
for !ctr.last {
bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
if err != nil {
return process.ExecStop, err
}
if sb == nil {
ctr.status = SendLastBatch
}
//else {
// sb, err = bat.Dup(proc.Mp())
// if err != nil {
// return process.ExecStop, err
// }
//}
fallthrough
case SendLastBatch:
if ctr.status == SendLastBatch {
ctr.status = SendB
sb = makeRecursiveBatch(proc)
if bat == nil || bat.End() {
proc.SetInputBatch(nil)
return process.ExecStop, nil
}
case SendB:
for !ctr.last {
bat, _, err = ctr.ReceiveFromSingleReg(1, anal)
if err != nil {
return process.ExecStop, err
}
if bat == nil {
proc.SetInputBatch(nil)
return process.ExecStop, nil
}
//b, err := bat.Dup(proc.Mp())
//if err != nil {
// return process.ExecStop, err
//}
//b.Recursive = bat.Recursive
ctr.bats = append(ctr.bats, bat)
if bat.Last() {
ctr.last = true
}
if bat.Last() {
ctr.last = true
}
sb = ctr.bats[0]
ctr.bats = ctr.bats[1:]
ctr.bats = append(ctr.bats, bat)
}
sb = ctr.bats[0]
ctr.bats = ctr.bats[1:]

if sb.Last() {
ctr.last = false
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/colexec/receiver_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ func (r *ReceiverOperator) ReceiveFromAllRegs(analyze process.Analyze) (*batch.B

if !ok {
logutil.Errorf("children pipeline closed unexpectedly")
r.removeChosen(chosen)
r.RemoveChosen(chosen)
return nil, true, nil
}

bat := (*batch.Batch)(value.UnsafePointer())
if bat == nil {
r.removeChosen(chosen)
r.RemoveChosen(chosen)
continue
}

Expand Down Expand Up @@ -154,7 +154,7 @@ func (r *ReceiverOperator) FreeMergeTypeOperator(failed bool) {
}
}

func (r *ReceiverOperator) removeChosen(idx int) {
func (r *ReceiverOperator) RemoveChosen(idx int) {
r.receiverListener = append(r.receiverListener[:idx], r.receiverListener[idx+1:]...)
r.aliveMergeReceiver--
}
2 changes: 1 addition & 1 deletion pkg/sql/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -1264,7 +1264,7 @@ func (c *Compile) compilePlanScope(ctx context.Context, step int32, curNodeIdx i
rs := c.newMergeScope(ss)
rs.appendInstruction(vm.Instruction{
Op: vm.Dispatch,
Arg: constructDispatchLocal(true, true, len(receivers) > 1, receivers),
Arg: constructDispatchLocal(true, true, n.RecursiveSink, receivers),
})

return []*Scope{rs}, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/plan/build_dml_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,7 @@ func collectSinkAndSinkScanMeta(
} else {
sinks[nodeId].step = oldStep
}
} else if node.NodeType == plan.Node_SINK_SCAN || node.NodeType == plan.Node_RECURSIVE_CTE {
} else if node.NodeType == plan.Node_SINK_SCAN || node.NodeType == plan.Node_RECURSIVE_CTE || node.NodeType == plan.Node_RECURSIVE_SCAN {
sinkNodeId := qry.Steps[node.SourceStep[0]]
if _, ok := sinks[sinkNodeId]; !ok {
sinks[sinkNodeId] = &sinkMeta{
Expand Down
10 changes: 6 additions & 4 deletions pkg/sql/plan/query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2762,20 +2762,22 @@ func (builder *QueryBuilder) buildTable(stmt tree.TableExpr, ctx *BindContext, p
}

sourceStep := builder.appendStep(recursiveLastNodeID)
nodeID = appendCTEScanNode(builder, ctx, sourceStep, initCtx.sinkTag)
nodeID = appendCTEScanNode(builder, ctx, initSourceStep, initCtx.sinkTag)
if limitExpr != nil || offsetExpr != nil {
node := builder.qry.Nodes[nodeID]
node.Limit = limitExpr
node.Offset = offsetExpr
}
for i := 0; i < len(recursiveSteps)-1; i++ {
for i := 0; i < len(recursiveSteps); i++ {
builder.qry.Nodes[nodeID].SourceStep = append(builder.qry.Nodes[nodeID].SourceStep, recursiveSteps[i])
}
curStep := int32(len(builder.qry.Steps))
for _, id := range recursiveNodeIDs {
builder.qry.Nodes[id].SourceStep = append(builder.qry.Nodes[id].SourceStep, curStep)
// builder.qry.Nodes[id].SourceStep = append(builder.qry.Nodes[id].SourceStep, curStep)
builder.qry.Nodes[id].SourceStep[0] = curStep
}
unionAllLastNodeID := appendSinkNodeWithTag(builder, ctx, nodeID, ctx.sinkTag)
builder.qry.Nodes[unionAllLastNodeID].RecursiveSink = true

// final statement
ctx.finalSelect = true
Expand All @@ -2786,7 +2788,7 @@ func (builder *QueryBuilder) buildTable(stmt tree.TableExpr, ctx *BindContext, p
}
sourceStep = builder.appendStep(unionAllLastNodeID)
nodeID = appendSinkScanNodeWithTag(builder, ctx, sourceStep, initCtx.sinkTag)
builder.qry.Nodes[nodeID].SourceStep = append(builder.qry.Nodes[nodeID].SourceStep, initSourceStep)
// builder.qry.Nodes[nodeID].SourceStep = append(builder.qry.Nodes[nodeID].SourceStep, initSourceStep)
}

break
Expand Down
1 change: 1 addition & 0 deletions proto/plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ message Node {
repeated RuntimeFilterSpec runtime_filter_build_list = 41;

bytes uuid = 42;
bool recursive_sink = 43;
}

message LockTarget {
Expand Down

0 comments on commit 72a0cb2

Please sign in to comment.