Skip to content

Commit

Permalink
sql: implement WITH RECURSIVE with UNION
Browse files Browse the repository at this point in the history
This change implements the UNION variant of WITH RECURSIVE, where rows
are deduplicated. We achieve this by storing all rows in a
deduplicating container and inserting in that container first,
detecting if the row is a duplicate.

Fixes cockroachdb#46642.

Release note (sql change): The WITH RECURSIVE variant that uses UNION
(as opposed to UNION ALL) is now supported.
  • Loading branch information
RaduBerinde committed Oct 19, 2021
1 parent 38168db commit 9a1f644
Show file tree
Hide file tree
Showing 15 changed files with 264 additions and 77 deletions.
16 changes: 7 additions & 9 deletions pkg/sql/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (a *applyJoinNode) startExec(params runParams) error {
}
}
a.run.out = make(tree.Datums, len(a.columns))
a.run.rightRows.init(a.rightTypes, params.extendedEvalCtx, "apply-join" /* opName */)
a.run.rightRows.init(a.rightTypes, params.extendedEvalCtx, "apply-join" /* opName */, false /* enableDeduplication */)
return nil
}

Expand Down Expand Up @@ -243,7 +243,8 @@ func (a *applyJoinNode) clearRightRows(params runParams) error {
// wrong during execution of the right hand side of the join, and that we should
// completely give up on the outer join.
func (a *applyJoinNode) runRightSidePlan(params runParams, plan *planComponents) error {
if err := runPlanInsidePlan(params, plan, &a.run.rightRows); err != nil {
rowResultWriter := NewRowResultWriter(&a.run.rightRows)
if err := runPlanInsidePlan(params, plan, rowResultWriter); err != nil {
return err
}
a.run.rightRowsIterator = newRowContainerIterator(params.ctx, a.run.rightRows, a.rightTypes)
Expand All @@ -252,12 +253,9 @@ func (a *applyJoinNode) runRightSidePlan(params runParams, plan *planComponents)

// runPlanInsidePlan is used to run a plan and gather the results in a row
// container, as part of the execution of an "outer" plan.
func runPlanInsidePlan(
params runParams, plan *planComponents, rowContainer *rowContainerHelper,
) error {
rowResultWriter := NewRowResultWriter(rowContainer)
func runPlanInsidePlan(params runParams, plan *planComponents, resultWriter rowResultWriter) error {
recv := MakeDistSQLReceiver(
params.ctx, rowResultWriter, tree.Rows,
params.ctx, resultWriter, tree.Rows,
params.ExecCfg().RangeDescriptorCache,
params.p.Txn(),
params.ExecCfg().Clock,
Expand Down Expand Up @@ -298,7 +296,7 @@ func runPlanInsidePlan(
recv,
&subqueryResultMemAcc,
) {
return rowResultWriter.Err()
return resultWriter.Err()
}
}

Expand All @@ -318,7 +316,7 @@ func runPlanInsidePlan(
params.p.extendedEvalCtx.ExecCfg.DistSQLPlanner.PlanAndRun(
params.ctx, evalCtx, planCtx, params.p.Txn(), plan.main, recv,
)()
return rowResultWriter.Err()
return resultWriter.Err()
}

func (a *applyJoinNode) Values() tree.Datums {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type bufferNode struct {

func (n *bufferNode) startExec(params runParams) error {
n.typs = planTypes(n.plan)
n.rows.init(n.typs, params.extendedEvalCtx, n.label)
n.rows.init(n.typs, params.extendedEvalCtx, n.label, false /* enableDeduplication */)
return nil
}

Expand Down
37 changes: 33 additions & 4 deletions pkg/sql/buffer_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/rowenc"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/encoding"
"github.com/cockroachdb/cockroach/pkg/util/mon"
)

Expand All @@ -34,30 +35,58 @@ type rowContainerHelper struct {
}

func (c *rowContainerHelper) init(
typs []*types.T, evalContext *extendedEvalContext, opName string,
typs []*types.T, evalContext *extendedEvalContext, opName string, enableDeduplication bool,
) {
distSQLCfg := &evalContext.DistSQLPlanner.distSQLSrv.ServerConfig
c.memMonitor = execinfra.NewLimitedMonitorNoFlowCtx(
evalContext.Context, evalContext.Mon, distSQLCfg, evalContext.SessionData(),
fmt.Sprintf("%s-limited", opName),
)
c.diskMonitor = execinfra.NewMonitor(evalContext.Context, distSQLCfg.ParentDiskMonitor, fmt.Sprintf("%s-disk", opName))
c.diskMonitor = execinfra.NewMonitor(
evalContext.Context, distSQLCfg.ParentDiskMonitor, fmt.Sprintf("%s-disk", opName),
)
c.rows = &rowcontainer.DiskBackedRowContainer{}
ordering := colinfo.NoOrdering
if enableDeduplication {
ordering = make(colinfo.ColumnOrdering, len(typs))
for i := range ordering {
ordering[i].ColIdx = i
ordering[i].Direction = encoding.Ascending
}
}
c.rows.Init(
colinfo.NoOrdering, typs, &evalContext.EvalContext,
ordering, typs, &evalContext.EvalContext,
distSQLCfg.TempStorage, c.memMonitor, c.diskMonitor,
)
if enableDeduplication {
c.rows.DoDeDuplicate()
}
c.scratch = make(rowenc.EncDatumRow, len(typs))
}

// addRow adds a given row to the helper and returns any error it encounters.
// addRow adds a given row to the container.
func (c *rowContainerHelper) addRow(ctx context.Context, row tree.Datums) error {
for i := range row {
c.scratch[i].Datum = row[i]
}
return c.rows.AddRow(ctx, c.scratch)
}

// addRowWithDedup adds a given row if not already present in the container.
// To use this method, init must have been called with enableDeduplication=true.
func (c *rowContainerHelper) addRowWithDedup(
ctx context.Context, row tree.Datums,
) (ok bool, _ error) {
for i := range row {
c.scratch[i].Datum = row[i]
}
lenBefore := c.rows.Len()
if _, err := c.rows.AddRowWithDeDup(ctx, c.scratch); err != nil {
return false, err
}
return c.rows.Len() > lenBefore, nil
}

// len returns the number of rows buffered so far.
func (c *rowContainerHelper) len() int {
return c.rows.Len()
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/distsql_running.go
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ func (dsp *DistSQLPlanner) planAndRunSubquery(
typs = subqueryPhysPlan.GetResultTypes()
}
var rows rowContainerHelper
rows.init(typs, evalCtx, "subquery" /* opName */)
rows.init(typs, evalCtx, "subquery" /* opName */, false /* enableDeduplication */)
defer rows.close(ctx)

// TODO(yuzefovich): consider implementing batch receiving result writer.
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/distsql_spec_exec_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ func (e *distSQLSpecExecFactory) ConstructScanBuffer(
}

func (e *distSQLSpecExecFactory) ConstructRecursiveCTE(
initial exec.Node, fn exec.RecursiveCTEIterationFn, label string,
initial exec.Node, fn exec.RecursiveCTEIterationFn, label string, deduplicate bool,
) (exec.Node, error) {
return nil, unimplemented.NewWithIssue(47473, "experimental opt-driven distsql planning: recursive CTE")
}
Expand Down
99 changes: 85 additions & 14 deletions pkg/sql/logictest/testdata/logic_test/with
Original file line number Diff line number Diff line change
Expand Up @@ -358,20 +358,19 @@ EXECUTE z5(3, 5)
3
5

# TODO(justin): re-enable this, we don't allow WITHs having outer columns.
# statement ok
# PREPARE z6(int) AS
# SELECT * FROM
# (VALUES (1), (2)) v(x),
# LATERAL (SELECT * FROM
# (WITH foo AS (SELECT $1 + x) SELECT * FROM foo)
# )

# query II
# EXECUTE z6(3)
# ----
# 1 4
# 2 5
statement ok
PREPARE z6(int) AS
SELECT * FROM
(VALUES (1), (2)) v(x),
LATERAL (SELECT * FROM
(WITH foo AS (SELECT $1 + x) SELECT * FROM foo)
)

query II
EXECUTE z6(3)
----
1 4
2 5

# Recursive CTE example from postgres docs.
query T
Expand All @@ -384,6 +383,18 @@ SELECT sum(n) FROM t
----
5050

# Similar example where many duplicate rows are generated but we use UNION to
# deduplicate them.
query T
WITH RECURSIVE t(n) AS (
VALUES (1)
UNION
SELECT n+y FROM t, (VALUES (1), (2)) AS v(y) WHERE n < 99
)
SELECT sum(n) FROM t
----
5050

# Test where initial query has duplicate columns.
query II
WITH RECURSIVE cte(a, b) AS (
Expand Down Expand Up @@ -554,6 +565,66 @@ WITH RECURSIVE points AS (
··························································································································
··························································································································

# Test that we deduplicate rows from the initial expression.
query II rowsort
WITH RECURSIVE cte(a, b) AS (
VALUES (2, 2), (1, 1), (1, 2), (1, 1), (1, 3), (1, 2), (2, 2)
UNION
SELECT a+10, b+10 FROM cte WHERE a < 20
) SELECT * FROM cte;
----
2 2
1 1
1 2
1 3
12 12
11 11
11 12
11 13
22 22
21 21
21 22
21 23

# Test that we deduplicate rows from a single iteration.
query II rowsort
WITH RECURSIVE cte(a, b) AS (
VALUES (1, 1), (1, 2), (2, 2)
UNION
SELECT 4-a, 4-a FROM cte
) SELECT * FROM cte;
----
1 1
1 2
2 2
3 3

# Test that we deduplicate rows across iterations.
query II rowsort
WITH RECURSIVE cte(a, b) AS (
VALUES (1, 1), (1, 2), (2, 2)
UNION
SELECT (a+i) % 4, (b+1-i) % 4 FROM cte, (VALUES (0), (1)) AS v(i)
) SELECT * FROM cte;
----
1 1
1 2
2 2
2 1
1 3
2 3
3 2
3 1
1 0
2 0
3 3
0 2
0 1
3 0
0 3
0 0


# Regression test for #45869 (CTE inside recursive CTE).
query T rowsort
WITH RECURSIVE x(a) AS (
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/exec/execbuilder/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -2055,7 +2055,7 @@ func (b *Builder) buildRecursiveCTE(rec *memo.RecursiveCTEExpr) (execPlan, error

label := fmt.Sprintf("working buffer (%s)", rec.Name)
var ep execPlan
ep.root, err = b.factory.ConstructRecursiveCTE(initial.root, fn, label)
ep.root, err = b.factory.ConstructRecursiveCTE(initial.root, fn, label, rec.Deduplicate)
if err != nil {
return execPlan{}, err
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/sql/opt/exec/factory.opt
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,13 @@ define ScanBuffer {
# function. The returned plan uses this reference with a
# ConstructScanBuffer call.
# - the plan is executed; the results are emitted and also saved in a new
# buffer for the next iteration.
# buffer for the next iteration. If Deduplicate is true, only rows that
# haven't been returned yet are saved.
define RecursiveCTE {
Initial exec.Node
Fn exec.RecursiveCTEIterationFn
Label string
Deduplicate bool
}

# ControlJobs implements PAUSE/CANCEL/RESUME JOBS.
Expand Down
3 changes: 3 additions & 0 deletions pkg/sql/opt/memo/expr_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,9 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {

case *RecursiveCTEExpr:
if !f.HasFlags(ExprFmtHideColumns) {
if t.Deduplicate {
tp.Childf("deduplicate")
}
tp.Childf("working table binding: &%d", t.WithID)
f.formatColList(e, tp, "initial columns:", t.InitialCols)
f.formatColList(e, tp, "recursive columns:", t.RecursiveCols)
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/opt/ops/relational.opt
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,10 @@ define RecursiveCTEPrivate {
# These columns are also used by the Recursive query to refer to the working
# table (see WithID).
OutCols ColList

# Deduplicate indicates if output rows should be deduplicated against all
# previous output rows (UNION variant, not UNION ALL).
Deduplicate bool
}

# FakeRel is a mock relational operator used for testing and as a dummy binding
Expand Down
Loading

0 comments on commit 9a1f644

Please sign in to comment.