Skip to content

Commit

Permalink
sql: implement WITH RECURSIVE with UNION
Browse files Browse the repository at this point in the history
This change implements the UNION variant of WITH RECURSIVE, where rows
are deduplicated. We achieve this by storing all rows in a
deduplicating container and inserting in that container first,
detecting if the row is a duplicate.

Fixes #46642.

Release note (sql change): The WITH RECURSIVE variant that uses UNION
(as opposed to UNION ALL) is now supported.
  • Loading branch information
RaduBerinde committed Oct 20, 2021
1 parent 6e110fe commit e030c45
Show file tree
Hide file tree
Showing 15 changed files with 332 additions and 88 deletions.
18 changes: 8 additions & 10 deletions pkg/sql/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,19 @@ func (a *applyJoinNode) clearRightRows(params runParams) error {
// wrong during execution of the right hand side of the join, and that we should
// completely give up on the outer join.
func (a *applyJoinNode) runRightSidePlan(params runParams, plan *planComponents) error {
if err := runPlanInsidePlan(params, plan, &a.run.rightRows); err != nil {
rowResultWriter := NewRowResultWriter(&a.run.rightRows)
if err := runPlanInsidePlan(params, plan, rowResultWriter); err != nil {
return err
}
a.run.rightRowsIterator = newRowContainerIterator(params.ctx, a.run.rightRows, a.rightTypes)
return nil
}

// runPlanInsidePlan is used to run a plan and gather the results in a row
// container, as part of the execution of an "outer" plan.
func runPlanInsidePlan(
params runParams, plan *planComponents, rowContainer *rowContainerHelper,
) error {
rowResultWriter := NewRowResultWriter(rowContainer)
// runPlanInsidePlan is used to run a plan and gather the results in the
// resultWriter, as part of the execution of an "outer" plan.
func runPlanInsidePlan(params runParams, plan *planComponents, resultWriter rowResultWriter) error {
recv := MakeDistSQLReceiver(
params.ctx, rowResultWriter, tree.Rows,
params.ctx, resultWriter, tree.Rows,
params.ExecCfg().RangeDescriptorCache,
params.p.Txn(),
params.ExecCfg().Clock,
Expand Down Expand Up @@ -298,7 +296,7 @@ func runPlanInsidePlan(
recv,
&subqueryResultMemAcc,
) {
return rowResultWriter.Err()
return resultWriter.Err()
}
}

Expand All @@ -318,7 +316,7 @@ func runPlanInsidePlan(
params.p.extendedEvalCtx.ExecCfg.DistSQLPlanner.PlanAndRun(
params.ctx, evalCtx, planCtx, params.p.Txn(), plan.main, recv,
)()
return rowResultWriter.Err()
return resultWriter.Err()
}

func (a *applyJoinNode) Values() tree.Datums {
Expand Down
65 changes: 56 additions & 9 deletions pkg/sql/buffer_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,25 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/rowenc"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/encoding"
"github.com/cockroachdb/cockroach/pkg/util/mon"
)

// rowContainerHelper is a wrapper around a disk-backed row container that
// should be used by planNodes (or similar components) whenever they need to
// buffer data. init must be called before the first use.
// buffer data. init or initWithDedup must be called before the first use.
type rowContainerHelper struct {
rows *rowcontainer.DiskBackedRowContainer
scratch rowenc.EncDatumRow
memMonitor *mon.BytesMonitor
diskMonitor *mon.BytesMonitor
rows *rowcontainer.DiskBackedRowContainer
scratch rowenc.EncDatumRow
}

func (c *rowContainerHelper) init(
typs []*types.T, evalContext *extendedEvalContext, opName string,
) {
c.initMonitors(evalContext, opName)
distSQLCfg := &evalContext.DistSQLPlanner.distSQLSrv.ServerConfig
c.memMonitor = execinfra.NewLimitedMonitorNoFlowCtx(
evalContext.Context, evalContext.Mon, distSQLCfg, evalContext.SessionData(),
fmt.Sprintf("%s-limited", opName),
)
c.diskMonitor = execinfra.NewMonitor(evalContext.Context, distSQLCfg.ParentDiskMonitor, fmt.Sprintf("%s-disk", opName))
c.rows = &rowcontainer.DiskBackedRowContainer{}
c.rows.Init(
colinfo.NoOrdering, typs, &evalContext.EvalContext,
Expand All @@ -50,14 +47,64 @@ func (c *rowContainerHelper) init(
c.scratch = make(rowenc.EncDatumRow, len(typs))
}

// addRow adds a given row to the helper and returns any error it encounters.
// initWithDedup is a variant of init that is used if row deduplication
// functionality is needed (see addRowWithDedup).
func (c *rowContainerHelper) initWithDedup(
typs []*types.T, evalContext *extendedEvalContext, opName string,
) {
c.initMonitors(evalContext, opName)
distSQLCfg := &evalContext.DistSQLPlanner.distSQLSrv.ServerConfig
c.rows = &rowcontainer.DiskBackedRowContainer{}
// The DiskBackedRowContainer can be configured to deduplicate along the
// columns in the ordering (these columns form the "key" if the container has
// to spill to disk).
ordering := make(colinfo.ColumnOrdering, len(typs))
for i := range ordering {
ordering[i].ColIdx = i
ordering[i].Direction = encoding.Ascending
}
c.rows.Init(
ordering, typs, &evalContext.EvalContext,
distSQLCfg.TempStorage, c.memMonitor, c.diskMonitor,
)
c.rows.DoDeDuplicate()
c.scratch = make(rowenc.EncDatumRow, len(typs))
}

func (c *rowContainerHelper) initMonitors(evalContext *extendedEvalContext, opName string) {
distSQLCfg := &evalContext.DistSQLPlanner.distSQLSrv.ServerConfig
c.memMonitor = execinfra.NewLimitedMonitorNoFlowCtx(
evalContext.Context, evalContext.Mon, distSQLCfg, evalContext.SessionData(),
fmt.Sprintf("%s-limited", opName),
)
c.diskMonitor = execinfra.NewMonitor(
evalContext.Context, distSQLCfg.ParentDiskMonitor, fmt.Sprintf("%s-disk", opName),
)
}

// addRow adds the given row to the container.
func (c *rowContainerHelper) addRow(ctx context.Context, row tree.Datums) error {
for i := range row {
c.scratch[i].Datum = row[i]
}
return c.rows.AddRow(ctx, c.scratch)
}

// addRowWithDedup adds the given row if not already present in the container.
// To use this method, initWithDedup must be used first.
func (c *rowContainerHelper) addRowWithDedup(
ctx context.Context, row tree.Datums,
) (added bool, _ error) {
for i := range row {
c.scratch[i].Datum = row[i]
}
lenBefore := c.rows.Len()
if _, err := c.rows.AddRowWithDeDup(ctx, c.scratch); err != nil {
return false, err
}
return c.rows.Len() > lenBefore, nil
}

// len returns the number of rows buffered so far.
func (c *rowContainerHelper) len() int {
return c.rows.Len()
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/distsql_spec_exec_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ func (e *distSQLSpecExecFactory) ConstructScanBuffer(
}

func (e *distSQLSpecExecFactory) ConstructRecursiveCTE(
initial exec.Node, fn exec.RecursiveCTEIterationFn, label string,
initial exec.Node, fn exec.RecursiveCTEIterationFn, label string, deduplicate bool,
) (exec.Node, error) {
return nil, unimplemented.NewWithIssue(47473, "experimental opt-driven distsql planning: recursive CTE")
}
Expand Down
99 changes: 85 additions & 14 deletions pkg/sql/logictest/testdata/logic_test/with
Original file line number Diff line number Diff line change
Expand Up @@ -358,20 +358,19 @@ EXECUTE z5(3, 5)
3
5

# TODO(justin): re-enable this, we don't allow WITHs having outer columns.
# statement ok
# PREPARE z6(int) AS
# SELECT * FROM
# (VALUES (1), (2)) v(x),
# LATERAL (SELECT * FROM
# (WITH foo AS (SELECT $1 + x) SELECT * FROM foo)
# )

# query II
# EXECUTE z6(3)
# ----
# 1 4
# 2 5
statement ok
PREPARE z6(int) AS
SELECT * FROM
(VALUES (1), (2)) v(x),
LATERAL (SELECT * FROM
(WITH foo AS (SELECT $1 + x) SELECT * FROM foo)
)

query II
EXECUTE z6(3)
----
1 4
2 5

# Recursive CTE example from postgres docs.
query T
Expand All @@ -384,6 +383,18 @@ SELECT sum(n) FROM t
----
5050

# Similar example where many duplicate rows are generated but we use UNION to
# deduplicate them.
query T
WITH RECURSIVE t(n) AS (
VALUES (1)
UNION
SELECT n+y FROM t, (VALUES (1), (2)) AS v(y) WHERE n < 99
)
SELECT sum(n) FROM t
----
5050

# Test where initial query has duplicate columns.
query II
WITH RECURSIVE cte(a, b) AS (
Expand Down Expand Up @@ -554,6 +565,66 @@ WITH RECURSIVE points AS (
··························································································································
··························································································································

# Test that we deduplicate rows from the initial expression.
query II rowsort
WITH RECURSIVE cte(a, b) AS (
VALUES (2, 2), (1, 1), (1, 2), (1, 1), (1, 3), (1, 2), (2, 2)
UNION
SELECT a+10, b+10 FROM cte WHERE a < 20
) SELECT * FROM cte;
----
2 2
1 1
1 2
1 3
12 12
11 11
11 12
11 13
22 22
21 21
21 22
21 23

# Test that we deduplicate rows from a single iteration.
query II rowsort
WITH RECURSIVE cte(a, b) AS (
VALUES (1, 1), (1, 2), (2, 2)
UNION
SELECT 4-a, 4-a FROM cte
) SELECT * FROM cte;
----
1 1
1 2
2 2
3 3

# Test that we deduplicate rows across iterations.
query II rowsort
WITH RECURSIVE cte(a, b) AS (
VALUES (1, 1), (1, 2), (2, 2)
UNION
SELECT (a+i) % 4, (b+1-i) % 4 FROM cte, (VALUES (0), (1)) AS v(i)
) SELECT * FROM cte;
----
1 1
1 2
2 2
2 1
1 3
2 3
3 2
3 1
1 0
2 0
3 3
0 2
0 1
3 0
0 3
0 0


# Regression test for #45869 (CTE inside recursive CTE).
query T rowsort
WITH RECURSIVE x(a) AS (
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/exec/execbuilder/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -2055,7 +2055,7 @@ func (b *Builder) buildRecursiveCTE(rec *memo.RecursiveCTEExpr) (execPlan, error

label := fmt.Sprintf("working buffer (%s)", rec.Name)
var ep execPlan
ep.root, err = b.factory.ConstructRecursiveCTE(initial.root, fn, label)
ep.root, err = b.factory.ConstructRecursiveCTE(initial.root, fn, label, rec.Deduplicate)
if err != nil {
return execPlan{}, err
}
Expand Down
32 changes: 32 additions & 0 deletions pkg/sql/opt/exec/execbuilder/testdata/with
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,38 @@ vectorized: true
size: 1 column, 1 row
row 0, expr 0: 1

query T
EXPLAIN (VERBOSE)
WITH RECURSIVE t(n) AS (
VALUES (1)
UNION
SELECT n+1 FROM t WHERE n < 100
)
SELECT sum(n) FROM t
----
distribution: local
vectorized: true
·
• group (scalar)
│ columns: (sum)
│ estimated row count: 1 (missing stats)
│ aggregate 0: sum(n)
└── • render
│ columns: (n)
│ estimated row count: 10 (missing stats)
│ render n: column1
└── • recursive cte
│ columns: (column1)
│ estimated row count: 10 (missing stats)
│ deduplicate
└── • values
columns: (column1)
size: 1 column, 1 row
row 0, expr 0: 1

# Tests with correlated CTEs.
query T
EXPLAIN
Expand Down
7 changes: 6 additions & 1 deletion pkg/sql/opt/exec/explain/emit.go
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,12 @@ func (e *emitter) emitNodeAttributes(n *Node) error {
}
e.emitSpans("spans", a.Table, a.Table.Index(cat.PrimaryIndex), params)

case recursiveCTEOp:
a := n.args.(*recursiveCTEArgs)
if e.ob.flags.Verbose && a.Deduplicate {
ob.Attrf("deduplicate", "")
}

case simpleProjectOp,
serializingProjectOp,
ordinalityOp,
Expand All @@ -809,7 +815,6 @@ func (e *emitter) emitNodeAttributes(n *Node) error {
alterTableUnsplitOp,
alterTableUnsplitAllOp,
alterTableRelocateOp,
recursiveCTEOp,
controlJobsOp,
controlSchedulesOp,
cancelQueriesOp,
Expand Down
4 changes: 3 additions & 1 deletion pkg/sql/opt/exec/factory.opt
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,13 @@ define ScanBuffer {
# function. The returned plan uses this reference with a
# ConstructScanBuffer call.
# - the plan is executed; the results are emitted and also saved in a new
# buffer for the next iteration.
# buffer for the next iteration. If Deduplicate is true, only rows that
# haven't been returned yet are emitted and saved.
define RecursiveCTE {
Initial exec.Node
Fn exec.RecursiveCTEIterationFn
Label string
Deduplicate bool
}

# ControlJobs implements PAUSE/CANCEL/RESUME JOBS.
Expand Down
3 changes: 3 additions & 0 deletions pkg/sql/opt/memo/expr_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,9 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {

case *RecursiveCTEExpr:
if !f.HasFlags(ExprFmtHideColumns) {
if t.Deduplicate {
tp.Childf("deduplicate")
}
tp.Childf("working table binding: &%d", t.WithID)
f.formatColList(e, tp, "initial columns:", t.InitialCols)
f.formatColList(e, tp, "recursive columns:", t.RecursiveCols)
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/opt/ops/relational.opt
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,10 @@ define RecursiveCTEPrivate {
# These columns are also used by the Recursive query to refer to the working
# table (see WithID).
OutCols ColList

# Deduplicate indicates if output rows should be deduplicated against all
# previous output rows (UNION variant, not UNION ALL).
Deduplicate bool
}

# FakeRel is a mock relational operator used for testing and as a dummy binding
Expand Down
Loading

0 comments on commit e030c45

Please sign in to comment.