From ec63604629e025f3b7174f62be64a0a8af401a68 Mon Sep 17 00:00:00 2001 From: Radu Berinde Date: Fri, 27 Sep 2019 09:57:43 -0400 Subject: [PATCH] opt: support recursive CTEs A recursive CTE is of the form: ``` WITH RECURSIVE cte AS ( UNION ALL ) ``` Recursive CTE evaluation (paraphrased from postgres docs): 1. Evaluate the initial query; emit the results and also save them in a "working" table. 2. So long as the working table is not empty: - evaluate the recursive query, substituting the current contents of the working table for the recursive self-reference. Emit all resulting rows, and save them into the next iteration's working table. This change adds optimizer and execution support for recursive CTEs. Some notes for the various components: - optbuilder: We build the recursive query using a `cteSource` that corresponds to the "working table". This code turned out to be tricky, in part because we want to allow non-recursive CTEs even if RECURSIVE is used (which is necessary when defining multiple CTEs, some of which are not recursive). - execution: the execution operator is somewhat similar to `applyJoinNode` (but simpler). The execbuilder provides a callback that can be used to regenerate the right-hand-side plan each time; this callback takes a reference to the working table (as a `bufferNode`). - execbuilder: to implement the callback mentioned above, we create an "inner" builder inside the callback which uses the same factory but is otherwise independent from the "outer" builder. Fixes #21085. Release note (sql change): WITH RECURSIVE is now supported. --- docs/generated/sql/bnf/delete_stmt.bnf | 2 +- docs/generated/sql/bnf/insert_stmt.bnf | 4 +- docs/generated/sql/bnf/select_stmt.bnf | 2 +- docs/generated/sql/bnf/stmt_block.bnf | 1 + docs/generated/sql/bnf/update_stmt.bnf | 2 +- docs/generated/sql/bnf/upsert_stmt.bnf | 2 +- docs/generated/sql/bnf/with_clause.bnf | 1 + pkg/sql/apply_join.go | 11 +- pkg/sql/buffer.go | 6 +- .../logictest/testdata/logic_test/with-opt | 181 +++++++ pkg/sql/opt/bench/stub_factory.go | 6 + pkg/sql/opt/exec/execbuilder/builder.go | 27 +- pkg/sql/opt/exec/execbuilder/relational.go | 66 +++ pkg/sql/opt/exec/execbuilder/testdata/with | 22 + pkg/sql/opt/exec/factory.go | 21 +- pkg/sql/opt/memo/expr_format.go | 21 +- pkg/sql/opt/memo/logical_props_builder.go | 31 ++ pkg/sql/opt/memo/statistics_builder.go | 2 +- pkg/sql/opt/ops/relational.opt | 41 ++ pkg/sql/opt/optbuilder/builder.go | 1 - pkg/sql/opt/optbuilder/delete.go | 2 +- pkg/sql/opt/optbuilder/insert.go | 2 +- pkg/sql/opt/optbuilder/project.go | 10 + pkg/sql/opt/optbuilder/scope.go | 18 +- pkg/sql/opt/optbuilder/select.go | 54 +- pkg/sql/opt/optbuilder/testdata/with | 499 ++++++++++++++++-- pkg/sql/opt/optbuilder/union.go | 24 +- pkg/sql/opt/optbuilder/update.go | 2 +- pkg/sql/opt/optbuilder/with.go | 177 ++++++- pkg/sql/opt_exec_factory.go | 11 + pkg/sql/opt_limits.go | 1 + pkg/sql/parser/parse_test.go | 6 +- pkg/sql/parser/sql.y | 5 +- pkg/sql/plan.go | 1 + pkg/sql/plan_columns.go | 2 + pkg/sql/plan_physical_props.go | 5 +- pkg/sql/recursive_cte.go | 133 +++++ pkg/sql/rowcontainer/datum_row_container.go | 4 + pkg/sql/sem/tree/pretty.go | 6 +- pkg/sql/sem/tree/with.go | 9 +- pkg/sql/walk.go | 7 + 41 files changed, 1295 insertions(+), 133 deletions(-) create mode 100644 pkg/sql/recursive_cte.go diff --git a/docs/generated/sql/bnf/delete_stmt.bnf b/docs/generated/sql/bnf/delete_stmt.bnf index 4ee2789125a4..2f2adf4d580a 100644 --- a/docs/generated/sql/bnf/delete_stmt.bnf +++ b/docs/generated/sql/bnf/delete_stmt.bnf @@ -1,2 +1,2 @@ delete_stmt ::= - ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'DELETE' 'FROM' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) + ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'DELETE' 'FROM' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) diff --git a/docs/generated/sql/bnf/insert_stmt.bnf b/docs/generated/sql/bnf/insert_stmt.bnf index 21794561ed60..f2f9a1cd139f 100644 --- a/docs/generated/sql/bnf/insert_stmt.bnf +++ b/docs/generated/sql/bnf/insert_stmt.bnf @@ -1,3 +1,3 @@ insert_stmt ::= - ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'INSERT' 'INTO' ( table_name | table_name 'AS' table_alias_name ) ( select_stmt | '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' select_stmt | 'DEFAULT' 'VALUES' ) ( 'RETURNING' ( ( target_elem ) ( ( ',' target_elem ) )* ) | 'RETURNING' 'NOTHING' | ) - | ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'INSERT' 'INTO' ( table_name | table_name 'AS' table_alias_name ) ( select_stmt | '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' select_stmt | 'DEFAULT' 'VALUES' ) on_conflict ( 'RETURNING' ( ( target_elem ) ( ( ',' target_elem ) )* ) | 'RETURNING' 'NOTHING' | ) + ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'INSERT' 'INTO' ( table_name | table_name 'AS' table_alias_name ) ( select_stmt | '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' select_stmt | 'DEFAULT' 'VALUES' ) ( 'RETURNING' ( ( target_elem ) ( ( ',' target_elem ) )* ) | 'RETURNING' 'NOTHING' | ) + | ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'INSERT' 'INTO' ( table_name | table_name 'AS' table_alias_name ) ( select_stmt | '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' select_stmt | 'DEFAULT' 'VALUES' ) on_conflict ( 'RETURNING' ( ( target_elem ) ( ( ',' target_elem ) )* ) | 'RETURNING' 'NOTHING' | ) diff --git a/docs/generated/sql/bnf/select_stmt.bnf b/docs/generated/sql/bnf/select_stmt.bnf index 44d75debef4c..9a1ade86bc91 100644 --- a/docs/generated/sql/bnf/select_stmt.bnf +++ b/docs/generated/sql/bnf/select_stmt.bnf @@ -1,3 +1,3 @@ select_stmt ::= - ( simple_select locking_clause | select_clause sort_clause locking_clause | select_clause ( sort_clause | ) ( limit_clause offset_clause | offset_clause limit_clause | limit_clause | offset_clause ) locking_clause | ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) select_clause locking_clause | ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) select_clause sort_clause locking_clause | ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) select_clause ( sort_clause | ) ( limit_clause offset_clause | offset_clause limit_clause | limit_clause | offset_clause ) locking_clause ) + ( simple_select locking_clause | select_clause sort_clause locking_clause | select_clause ( sort_clause | ) ( limit_clause offset_clause | offset_clause limit_clause | limit_clause | offset_clause ) locking_clause | ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) select_clause locking_clause | ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) select_clause sort_clause locking_clause | ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) select_clause ( sort_clause | ) ( limit_clause offset_clause | offset_clause limit_clause | limit_clause | offset_clause ) locking_clause ) diff --git a/docs/generated/sql/bnf/stmt_block.bnf b/docs/generated/sql/bnf/stmt_block.bnf index 6169ee74a8e5..006c79b57203 100644 --- a/docs/generated/sql/bnf/stmt_block.bnf +++ b/docs/generated/sql/bnf/stmt_block.bnf @@ -1036,6 +1036,7 @@ opt_create_stats_options ::= with_clause ::= 'WITH' cte_list + | 'WITH' 'RECURSIVE' cte_list table_name_expr_with_index ::= table_name opt_index_flags diff --git a/docs/generated/sql/bnf/update_stmt.bnf b/docs/generated/sql/bnf/update_stmt.bnf index 9a2b313b1a0e..b2ced90abb05 100644 --- a/docs/generated/sql/bnf/update_stmt.bnf +++ b/docs/generated/sql/bnf/update_stmt.bnf @@ -1,2 +1,2 @@ update_stmt ::= - ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPDATE' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) 'SET' ( ( ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) ( ( ',' ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) )* ) opt_from_list ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) + ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPDATE' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) 'SET' ( ( ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) ( ( ',' ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) )* ) opt_from_list ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) diff --git a/docs/generated/sql/bnf/upsert_stmt.bnf b/docs/generated/sql/bnf/upsert_stmt.bnf index f936ddbc279a..0f3fc904c976 100644 --- a/docs/generated/sql/bnf/upsert_stmt.bnf +++ b/docs/generated/sql/bnf/upsert_stmt.bnf @@ -1,2 +1,2 @@ upsert_stmt ::= - ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPSERT' 'INTO' ( table_name | table_name 'AS' table_alias_name ) ( select_stmt | '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' select_stmt | 'DEFAULT' 'VALUES' ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) + ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) | 'WITH' 'RECURSIVE' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPSERT' 'INTO' ( table_name | table_name 'AS' table_alias_name ) ( select_stmt | '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' select_stmt | 'DEFAULT' 'VALUES' ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) diff --git a/docs/generated/sql/bnf/with_clause.bnf b/docs/generated/sql/bnf/with_clause.bnf index c4058a1f57a9..c7e9930aa14d 100644 --- a/docs/generated/sql/bnf/with_clause.bnf +++ b/docs/generated/sql/bnf/with_clause.bnf @@ -1,2 +1,3 @@ with_clause ::= 'WITH' ( ( ( table_alias_name ( '(' ( ( name ) ( ( ',' name ) )* ) ')' | ) 'AS' '(' preparable_stmt ')' ) ) ( ( ',' ( table_alias_name ( '(' ( ( name ) ( ( ',' name ) )* ) ')' | ) 'AS' '(' preparable_stmt ')' ) ) )* ) ( insert_stmt | update_stmt | delete_stmt | upsert_stmt | select_stmt ) + | 'WITH' 'RECURSIVE' ( ( ( table_alias_name ( '(' ( ( name ) ( ( ',' name ) )* ) ')' | ) 'AS' '(' preparable_stmt ')' ) ) ( ( ',' ( table_alias_name ( '(' ( ( name ) ( ( ',' name ) )* ) ')' | ) 'AS' '(' preparable_stmt ')' ) ) )* ) ( insert_stmt | update_stmt | delete_stmt | upsert_stmt | select_stmt ) diff --git a/pkg/sql/apply_join.go b/pkg/sql/apply_join.go index 5bebbc35feed..2630094ac04e 100644 --- a/pkg/sql/apply_join.go +++ b/pkg/sql/apply_join.go @@ -294,7 +294,15 @@ func (a *applyJoinNode) Next(params runParams) (bool, error) { func (a *applyJoinNode) runRightSidePlan(params runParams, plan *planTop) error { a.run.curRightRow = 0 a.run.rightRows.Clear(params.ctx) - rowResultWriter := NewRowResultWriter(a.run.rightRows) + return runPlanInsidePlan(params, plan, a.run.rightRows) +} + +// runPlanInsidePlan is used to run a plan and gather the results in a row +// container, as part of the execution of an "outer" plan. +func runPlanInsidePlan( + params runParams, plan *planTop, rowContainer *rowcontainer.RowContainer, +) error { + rowResultWriter := NewRowResultWriter(rowContainer) recv := MakeDistSQLReceiver( params.ctx, rowResultWriter, tree.Rows, params.extendedEvalCtx.ExecCfg.RangeDescriptorCache, @@ -339,7 +347,6 @@ func (a *applyJoinNode) runRightSidePlan(params runParams, plan *planTop) error return recv.commErr } return rowResultWriter.err - } func (a *applyJoinNode) Values() tree.Datums { diff --git a/pkg/sql/buffer.go b/pkg/sql/buffer.go index 93f044cae0b6..630e1f431372 100644 --- a/pkg/sql/buffer.go +++ b/pkg/sql/buffer.go @@ -67,11 +67,7 @@ func (n *bufferNode) Values() tree.Datums { func (n *bufferNode) Close(ctx context.Context) { n.plan.Close(ctx) - // It's valid to be Closed without startExec having been called, in which - // case n.bufferedRows will be nil. - if n.bufferedRows != nil { - n.bufferedRows.Close(ctx) - } + n.bufferedRows.Close(ctx) } // scanBufferNode behaves like an iterator into the bufferNode it is diff --git a/pkg/sql/logictest/testdata/logic_test/with-opt b/pkg/sql/logictest/testdata/logic_test/with-opt index 2082ebf88a25..b1728dd3270b 100644 --- a/pkg/sql/logictest/testdata/logic_test/with-opt +++ b/pkg/sql/logictest/testdata/logic_test/with-opt @@ -155,3 +155,184 @@ EXECUTE z6(3) ---- 1 4 2 5 + +# Recursive CTE example from postgres docs. +query T +WITH RECURSIVE t(n) AS ( + VALUES (1) + UNION ALL + SELECT n+1 FROM t WHERE n < 100 +) +SELECT sum(n) FROM t +---- +5050 + +# Test where initial query has duplicate columns. +query II +WITH RECURSIVE cte(a, b) AS ( + SELECT 0, 0 + UNION ALL + SELECT a+1, b+10 FROM cte WHERE a < 5 +) SELECT * FROM cte; +---- +0 0 +1 10 +2 20 +3 30 +4 40 +5 50 + +# Test where recursive query has duplicate columns. +query II +WITH RECURSIVE cte(a, b) AS ( + SELECT 0, 1 + UNION ALL + SELECT a+1, a+1 FROM cte WHERE a < 5 +) SELECT * FROM cte; +---- +0 1 +1 1 +2 2 +3 3 +4 4 +5 5 + +# Recursive CTE examples adapted from +# https://malisper.me/generating-fractals-with-postgres-escape-time-fractals. +query T +WITH RECURSIVE points AS ( + SELECT i::float * 0.05 AS r, j::float * 0.05 AS c + FROM generate_series(-20, 20) AS a (i), generate_series(-40, 20) AS b (j) +), iterations AS ( + SELECT r, + c, + 0.0::float AS zr, + 0.0::float AS zc, + 0 AS iteration + FROM points + UNION ALL + SELECT r, + c, + zr*zr - zc*zc + c AS zr, + 2*zr*zc + r AS zc, + iteration+1 AS iteration + FROM iterations WHERE zr*zr + zc*zc < 4 AND iteration < 20 +), final_iteration AS ( + SELECT * FROM iterations WHERE iteration = 20 +), marked_points AS ( + SELECT r, + c, + (CASE WHEN EXISTS (SELECT 1 FROM final_iteration i WHERE p.r = i.r AND p.c = i.c) + THEN 'oo' ELSE '··' END) AS marker FROM points p +), lines AS ( + SELECT r, string_agg(marker, '' ORDER BY c ASC) AS r_text + FROM marked_points + GROUP BY r +) SELECT string_agg(r_text, E'\n' ORDER BY r DESC) FROM lines +---- +················································································oo········································ +············································································oo············································ +··········································································oooo············································ +······································································oo··oooo············································ +········································································oooooooo·········································· +······································································oooooooooooo········································ +········································································oooooooo·········································· +··························································oo····oooooooooooooooooooo··oo·································· +··························································oooo··oooooooooooooooooooooooo·································· +··························································oooooooooooooooooooooooooooooooooooooo·························· +··························································oooooooooooooooooooooooooooooooooooooo·························· +····················································oooooooooooooooooooooooooooooooooooooooooo···························· +······················································oooooooooooooooooooooooooooooooooooooooo···························· +····················································oooooooooooooooooooooooooooooooooooooooooooooo························ +··································oo····oo··········oooooooooooooooooooooooooooooooooooooooooooo·························· +··································oooooooooooo······oooooooooooooooooooooooooooooooooooooooooooo·························· +··································oooooooooooooo····oooooooooooooooooooooooooooooooooooooooooooooo························ +································oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo························ +······························oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·························· +··························oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo···························· +··oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo······························ +··························oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo···························· +······························oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·························· +································oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo························ +··································oooooooooooooo····oooooooooooooooooooooooooooooooooooooooooooooo························ +··································oooooooooooo······oooooooooooooooooooooooooooooooooooooooooooo·························· +··································oo····oo··········oooooooooooooooooooooooooooooooooooooooooooo·························· +····················································oooooooooooooooooooooooooooooooooooooooooooooo························ +······················································oooooooooooooooooooooooooooooooooooooooo···························· +····················································oooooooooooooooooooooooooooooooooooooooooo···························· +··························································oooooooooooooooooooooooooooooooooooooo·························· +··························································oooooooooooooooooooooooooooooooooooooo·························· +··························································oooo··oooooooooooooooooooooooo·································· +··························································oo····oooooooooooooooooooo··oo·································· +········································································oooooooo·········································· +······································································oooooooooooo········································ +········································································oooooooo·········································· +······································································oo··oooo············································ +··········································································oooo············································ +············································································oo············································ +················································································oo········································ + +query T +WITH RECURSIVE points AS ( + SELECT i::float * 0.05 AS r, j::float * 0.05 AS c + FROM generate_series(-20, 20) AS a (i), generate_series(-30, 30) AS b (j) +), iterations AS ( + SELECT r, c, c::float AS zr, r::float AS zc, 0 AS iteration FROM points + UNION ALL + SELECT r, c, zr*zr - zc*zc + 1 - 1.61803398875 AS zr, 2*zr*zc AS zc, iteration+1 AS iteration + FROM iterations WHERE zr*zr + zc*zc < 4 AND iteration < 20 +), final_iteration AS ( + SELECT * FROM iterations WHERE iteration = 20 +), marked_points AS ( + SELECT r, c, (CASE WHEN EXISTS (SELECT 1 FROM final_iteration i WHERE p.r = i.r AND p.c = i.c) + THEN 'oo' + ELSE '··' + END) AS marker + FROM points p +), rows AS ( + SELECT r, string_agg(marker, '' ORDER BY c ASC) AS r_text + FROM marked_points + GROUP BY r +) SELECT string_agg(r_text, E'\n' ORDER BY r DESC) FROM rows +---- +·························································································································· +·························································································································· +····························································oo···························································· +····························································oo···························································· +························································oooooooooo························································ +························································oooooooooo························································ +························································oooooooooo························································ +··············································oo··oooooooooooooooooooooo··oo·············································· +··············································oooooooooooooooooooooooooooooo·············································· +············································oooooooooooooooooooooooooooooooooo············································ +··········································oooooooooooooooooooooooooooooooooooooo·········································· +························oooo····oo········oooooooooooooooooooooooooooooooooooooo········oo····oooo························ +························oooooooooooooo····oooooooooooooooooooooooooooooooooooooo····oooooooooooooo························ +······················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo······················ +····················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo···················· +··················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·················· +··················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·················· +··········oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·········· +··········oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·········· +······oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo······ +····oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo···· +······oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo······ +··········oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·········· +··········oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·········· +··················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·················· +··················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo·················· +····················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo···················· +······················oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo······················ +························oooooooooooooo····oooooooooooooooooooooooooooooooooooooo····oooooooooooooo························ +························oooo····oo········oooooooooooooooooooooooooooooooooooooo········oo····oooo························ +··········································oooooooooooooooooooooooooooooooooooooo·········································· +············································oooooooooooooooooooooooooooooooooo············································ +··············································oooooooooooooooooooooooooooooo·············································· +··············································oo··oooooooooooooooooooooo··oo·············································· +························································oooooooooo························································ +························································oooooooooo························································ +························································oooooooooo························································ +····························································oo···························································· +····························································oo···························································· +·························································································································· +·························································································································· diff --git a/pkg/sql/opt/bench/stub_factory.go b/pkg/sql/opt/bench/stub_factory.go index 0edbf4c6a2cc..0e83c349ef16 100644 --- a/pkg/sql/opt/bench/stub_factory.go +++ b/pkg/sql/opt/bench/stub_factory.go @@ -339,6 +339,12 @@ func (f *stubFactory) ConstructScanBuffer(ref exec.Node, label string) (exec.Nod return struct{}{}, nil } +func (f *stubFactory) ConstructRecursiveCTE( + initial exec.Node, fn exec.RecursiveCTEIterationFn, label string, +) (exec.Node, error) { + return struct{}{}, nil +} + func (f *stubFactory) ConstructControlJobs( command tree.JobCommand, input exec.Node, ) (exec.Node, error) { diff --git a/pkg/sql/opt/exec/execbuilder/builder.go b/pkg/sql/opt/exec/execbuilder/builder.go index 744286c40d47..c44453954e91 100644 --- a/pkg/sql/opt/exec/execbuilder/builder.go +++ b/pkg/sql/opt/exec/execbuilder/builder.go @@ -93,6 +93,14 @@ func (b *Builder) DisableTelemetry() { // Build constructs the execution node tree and returns its root node if no // error occurred. func (b *Builder) Build() (_ exec.Plan, err error) { + plan, err := b.build(b.e) + if err != nil { + return nil, err + } + return b.factory.ConstructPlan(plan.root, b.subqueries, b.postqueries) +} + +func (b *Builder) build(e opt.Expr) (_ execPlan, err error) { defer func() { if r := recover(); r != nil { // This code allows us to propagate errors without adding lots of checks @@ -107,27 +115,16 @@ func (b *Builder) Build() (_ exec.Plan, err error) { } }() - root, err := b.build(b.e) - if err != nil { - return nil, err - } - return b.factory.ConstructPlan(root, b.subqueries, b.postqueries) -} - -func (b *Builder) build(e opt.Expr) (exec.Node, error) { rel, ok := e.(memo.RelExpr) if !ok { - return nil, errors.AssertionFailedf("building execution for non-relational operator %s", log.Safe(e.Op())) + return execPlan{}, errors.AssertionFailedf( + "building execution for non-relational operator %s", log.Safe(e.Op()), + ) } b.autoCommit = b.canAutoCommit(rel) - plan, err := b.buildRelational(rel) - if err != nil { - return nil, err - } - - return plan.root, nil + return b.buildRelational(rel) } // BuildScalar converts a scalar expression to a TypedExpr. Variables are mapped diff --git a/pkg/sql/opt/exec/execbuilder/relational.go b/pkg/sql/opt/exec/execbuilder/relational.go index 5c8071f584fb..67db482e46c0 100644 --- a/pkg/sql/opt/exec/execbuilder/relational.go +++ b/pkg/sql/opt/exec/execbuilder/relational.go @@ -258,6 +258,9 @@ func (b *Builder) buildRelational(e memo.RelExpr) (execPlan, error) { case *memo.WithScanExpr: ep, err = b.buildWithScan(t) + case *memo.RecursiveCTEExpr: + ep, err = b.buildRecursiveCTE(t) + case *memo.ExplainExpr: ep, err = b.buildExplain(t) @@ -1424,6 +1427,69 @@ func (b *Builder) buildWith(with *memo.WithExpr) (execPlan, error) { return b.buildRelational(with.Main) } +func (b *Builder) buildRecursiveCTE(rec *memo.RecursiveCTEExpr) (execPlan, error) { + initial, err := b.buildRelational(rec.Initial) + if err != nil { + return execPlan{}, err + } + + // Make sure we have the columns in the correct order. + initial, err = b.ensureColumns(initial, rec.InitialCols, nil /* colNames */, nil /* ordering */) + if err != nil { + return execPlan{}, err + } + + // Renumber the columns so they match the columns expected by the recursive + // query. + initial.outputCols = util.FastIntMap{} + for i, col := range rec.OutCols { + initial.outputCols.Set(int(col), i) + } + + // To implement exec.RecursiveCTEIterationFn, we create a special Builder. + + innerBldTemplate := &Builder{ + factory: b.factory, + mem: b.mem, + catalog: b.catalog, + evalCtx: b.evalCtx, + // If the recursive query itself contains CTEs, building it in the function + // below will add to withExprs. Cap the slice to force reallocation on any + // appends, so that they don't overwrite overwrite later appends by our + // original builder. + withExprs: b.withExprs[:len(b.withExprs):len(b.withExprs)], + } + + fn := func(bufferRef exec.Node) (exec.Plan, error) { + // Use a separate builder each time. + innerBld := *innerBldTemplate + innerBld.addBuiltWithExpr(rec.WithID, initial.outputCols, bufferRef) + plan, err := innerBld.build(rec.Recursive) + if err != nil { + return nil, err + } + // Ensure columns are output in the same order. + plan, err = innerBld.ensureColumns( + plan, rec.RecursiveCols, nil /* colNames */, nil, /* ordering */ + ) + if err != nil { + return nil, err + } + return innerBld.factory.ConstructPlan(plan.root, innerBld.subqueries, innerBld.postqueries) + } + + label := fmt.Sprintf("working buffer (%s)", rec.Name) + var ep execPlan + ep.root, err = b.factory.ConstructRecursiveCTE(initial.root, fn, label) + if err != nil { + return execPlan{}, err + } + for i, col := range rec.OutCols { + ep.outputCols.Set(int(col), i) + } + return ep, nil +} + func (b *Builder) buildWithScan(withScan *memo.WithScanExpr) (execPlan, error) { id := withScan.ID var e *builtWithExpr diff --git a/pkg/sql/opt/exec/execbuilder/testdata/with b/pkg/sql/opt/exec/execbuilder/testdata/with index e220ca440b3a..5d49af8ba440 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/with +++ b/pkg/sql/opt/exec/execbuilder/testdata/with @@ -92,3 +92,25 @@ hash-join · · (col) · └── scan · · (col) · · table table39010@primary · · · spans ALL · · + +query TTTTT +EXPLAIN (VERBOSE) + WITH RECURSIVE t(n) AS ( + VALUES (1) + UNION ALL + SELECT n+1 FROM t WHERE n < 100 + ) + SELECT sum(n) FROM t +---- +· distributed false · · +· vectorized false · · +group · · (sum) · + │ aggregate 0 sum(n) · · + │ scalar · · · + └── render · · (n) · + │ render 0 column1 · · + └── recursive cte node · · (column1) · + │ label working buffer (t) · · + └── values · · (column1) · +· size 1 column, 1 row · · +· row 0, expr 0 1 · · diff --git a/pkg/sql/opt/exec/factory.go b/pkg/sql/opt/exec/factory.go index aa89cd4cb2e1..e436447b50a5 100644 --- a/pkg/sql/opt/exec/factory.go +++ b/pkg/sql/opt/exec/factory.go @@ -473,12 +473,24 @@ type Factory interface { // ConstructBuffer constructs a node whose input can be referenced from // elsewhere in the query. - ConstructBuffer(value Node, label string) (Node, error) + ConstructBuffer(input Node, label string) (Node, error) // ConstructScanBuffer constructs a node which refers to a node constructed by - // ConstructBuffer. + // ConstructBuffer or passed to RecursiveCTEIterationFn. ConstructScanBuffer(ref Node, label string) (Node, error) + // ConstructRecursiveCTE constructs a node that executes a recursive CTE: + // * the initial plan is run first; the results are emitted and also saved + // in a buffer. + // * so long as the last buffer is not empty: + // - the RecursiveCTEIterationFn is used to create a plan for the + // recursive side; a reference to the last buffer is passed to this + // function. The returned plan uses this reference with a + // ConstructScanBuffer call. + // - the plan is executed; the results are emitted and also saved in a new + // buffer for the next iteration. + ConstructRecursiveCTE(initial Node, fn RecursiveCTEIterationFn, label string) (Node, error) + // ConstructControlJobs creates a node that implements PAUSE/CANCEL/RESUME // JOBS. ConstructControlJobs(command tree.JobCommand, input Node) (Node, error) @@ -613,3 +625,8 @@ type KVOption struct { // If there is no value, Value is DNull. Value tree.TypedExpr } + +// RecursiveCTEIterationFn creates a plan for an iteration of WITH RECURSIVE, +// given the result of the last iteration (as a Buffer that can be used with +// ConstructScanBuffer). +type RecursiveCTEIterationFn func(bufferRef Node) (Plan, error) diff --git a/pkg/sql/opt/memo/expr_format.go b/pkg/sql/opt/memo/expr_format.go index 1991cbb0bf00..04ca14d08041 100644 --- a/pkg/sql/opt/memo/expr_format.go +++ b/pkg/sql/opt/memo/expr_format.go @@ -208,17 +208,15 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) { } case *WithExpr: - w := e.(*WithExpr) - fmt.Fprintf(f.Buffer, "%v &%d", e.Op(), w.ID) - if w.Name != "" { - fmt.Fprintf(f.Buffer, " (%s)", w.Name) + fmt.Fprintf(f.Buffer, "%v &%d", e.Op(), t.ID) + if t.Name != "" { + fmt.Fprintf(f.Buffer, " (%s)", t.Name) } case *WithScanExpr: - ws := e.(*WithScanExpr) - fmt.Fprintf(f.Buffer, "%v &%d", e.Op(), ws.ID) - if ws.Name != "" { - fmt.Fprintf(f.Buffer, " (%s)", ws.Name) + fmt.Fprintf(f.Buffer, "%v &%d", e.Op(), t.ID) + if t.Name != "" { + fmt.Fprintf(f.Buffer, " (%s)", t.Name) } default: @@ -469,6 +467,13 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) { tp.Childf("mode: %s", m) } + case *RecursiveCTEExpr: + if !f.HasFlags(ExprFmtHideColumns) { + tp.Childf("working table binding: &%d", t.WithID) + f.formatColList(e, tp, "initial columns:", t.InitialCols) + f.formatColList(e, tp, "recursive columns:", t.RecursiveCols) + } + default: if opt.IsJoinOp(t) { p := t.Private().(*JoinPrivate) diff --git a/pkg/sql/opt/memo/logical_props_builder.go b/pkg/sql/opt/memo/logical_props_builder.go index ced500bc7c41..3f84ffe734df 100644 --- a/pkg/sql/opt/memo/logical_props_builder.go +++ b/pkg/sql/opt/memo/logical_props_builder.go @@ -792,6 +792,37 @@ func (b *logicalPropsBuilder) buildWithScanProps(withScan *WithScanExpr, rel *pr } } +func (b *logicalPropsBuilder) buildRecursiveCTEProps(rec *RecursiveCTEExpr, rel *props.Relational) { + BuildSharedProps(b.mem, rec, &rel.Shared) + + // Output Columns + // -------------- + rel.OutputCols = rec.OutCols.ToSet() + + // Not Null Columns + // ---------------- + // All columns are assumed to be nullable. + + // Outer Columns + // ------------- + // No outer columns. + + // Functional Dependencies + // ----------------------- + // No known FDs. + + // Cardinality + // ----------- + // At least the cardinality of the initial buffer. + rel.Cardinality = props.AnyCardinality.AtLeast(rec.Initial.Relational().Cardinality) + + // Statistics + // ---------- + if !b.disableStats { + b.sb.buildUnknown(rel) + } +} + func (b *logicalPropsBuilder) buildExplainProps(explain *ExplainExpr, rel *props.Relational) { b.buildBasicProps(explain, explain.ColList, rel) } diff --git a/pkg/sql/opt/memo/statistics_builder.go b/pkg/sql/opt/memo/statistics_builder.go index fb1c2d735eef..88264e93325c 100644 --- a/pkg/sql/opt/memo/statistics_builder.go +++ b/pkg/sql/opt/memo/statistics_builder.go @@ -379,7 +379,7 @@ func (sb *statisticsBuilder) colStat(colSet opt.ColSet, e RelExpr) *props.Column return sb.colStatSequenceSelect(colSet, e.(*SequenceSelectExpr)) case opt.ExplainOp, opt.ShowTraceForSessionOp, - opt.OpaqueRelOp, opt.OpaqueMutationOp, opt.OpaqueDDLOp: + opt.OpaqueRelOp, opt.OpaqueMutationOp, opt.OpaqueDDLOp, opt.RecursiveCTEOp: return sb.colStatUnknown(colSet, e.Relational()) case opt.WithOp: diff --git a/pkg/sql/opt/ops/relational.opt b/pkg/sql/opt/ops/relational.opt index b0c943f50417..9b1e6a2db1fb 100644 --- a/pkg/sql/opt/ops/relational.opt +++ b/pkg/sql/opt/ops/relational.opt @@ -863,6 +863,47 @@ define WithScanPrivate { OutCols ColList } +# RecursiveCTE implements the logic of a recursive CTE: +# * the Initial query is evaluated; the results are emitted and also saved into +# a "working table". +# * so long as the working table is not empty: +# - the Recursive query (which refers to the working table using a specific +# WithID) is evaluated; the results are emitted and also saved into a new +# "working table" for the next iteration. +[Relational] +define RecursiveCTE { + Initial RelExpr + Recursive RelExpr + + _ RecursiveCTEPrivate +} + +[Private] +define RecursiveCTEPrivate { + # Name is used to identify the CTE being referenced for debugging purposes. + Name string + + # WithID is the ID through which the Recursive expression refers to the + # current working table. + WithID WithID + + # InitialCols are the columns produced by the initial expression. + InitialCols ColList + + # RecursiveCols are the columns produced by the recursive expression, that + # map 1-1 to InitialCols. + RecursiveCols ColList + + # OutCols are the columns produced by the RecursiveCTE operator; they map + # 1-1 to InitialCols and to RecursiveCols. Similar to Union, we don't want + # to reuse column IDs from one side because the columns contain values from + # both sides. + # + # These columns are also used by the Recursive query to refer to the working + # table (see WithID). + OutCols ColList +} + # FakeRel is a mock relational operator used for testing; its logical properties # are pre-determined and stored in the private. It can be used as the child of # an operator for which we are calculating properties or statistics. diff --git a/pkg/sql/opt/optbuilder/builder.go b/pkg/sql/opt/optbuilder/builder.go index a7b79fc8f8af..5f0d67973f86 100644 --- a/pkg/sql/opt/optbuilder/builder.go +++ b/pkg/sql/opt/optbuilder/builder.go @@ -91,7 +91,6 @@ type Builder struct { evalCtx *tree.EvalContext catalog cat.Catalog scopeAlloc []scope - ctes []cteSource // If set, the planner will skip checking for the SELECT privilege when // resolving data sources (tables, views, etc). This is used when compiling diff --git a/pkg/sql/opt/optbuilder/delete.go b/pkg/sql/opt/optbuilder/delete.go index 8c4e9cc359f1..9f1fd448ff0a 100644 --- a/pkg/sql/opt/optbuilder/delete.go +++ b/pkg/sql/opt/optbuilder/delete.go @@ -41,7 +41,7 @@ func (b *Builder) buildDelete(del *tree.Delete, inScope *scope) (outScope *scope var ctes []cteSource if del.With != nil { - inScope, ctes = b.buildCTE(del.With.CTEList, inScope) + inScope, ctes = b.buildCTEs(del.With, inScope) } // DELETE FROM xx AS yy - we want to know about xx (tn) because diff --git a/pkg/sql/opt/optbuilder/insert.go b/pkg/sql/opt/optbuilder/insert.go index 5352a20efdf9..9b36d9fd5173 100644 --- a/pkg/sql/opt/optbuilder/insert.go +++ b/pkg/sql/opt/optbuilder/insert.go @@ -156,7 +156,7 @@ func init() { func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope) { var ctes []cteSource if ins.With != nil { - inScope, ctes = b.buildCTE(ins.With.CTEList, inScope) + inScope, ctes = b.buildCTEs(ins.With, inScope) } // INSERT INTO xx AS yy - we want to know about xx (tn) because diff --git a/pkg/sql/opt/optbuilder/project.go b/pkg/sql/opt/optbuilder/project.go index 8228fcec7cb8..82aed91fb3f2 100644 --- a/pkg/sql/opt/optbuilder/project.go +++ b/pkg/sql/opt/optbuilder/project.go @@ -58,6 +58,16 @@ func (b *Builder) constructProject(input memo.RelExpr, cols []scopeColumn) memo. return b.factory.ConstructProject(input, projections, passthrough) } +// dropOrderingAndExtraCols removes the ordering in the scope and projects away +// any extra columns. +func (b *Builder) dropOrderingAndExtraCols(s *scope) { + s.ordering = nil + if len(s.extraCols) > 0 { + s.extraCols = nil + s.expr = b.constructProject(s.expr, s.cols) + } +} + // analyzeProjectionList analyzes the given list of SELECT clause expressions, // and adds the resulting aliases and typed expressions to outScope. See the // header comment for analyzeSelectList. diff --git a/pkg/sql/opt/optbuilder/scope.go b/pkg/sql/opt/optbuilder/scope.go index fe975d5ca04b..b5e1a8695305 100644 --- a/pkg/sql/opt/optbuilder/scope.go +++ b/pkg/sql/opt/optbuilder/scope.go @@ -18,6 +18,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" + "github.com/cockroachdb/cockroach/pkg/sql/opt/props" "github.com/cockroachdb/cockroach/pkg/sql/opt/props/physical" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" @@ -94,11 +95,15 @@ type scope struct { // cteSource represents a CTE in the given query. type cteSource struct { + id opt.WithID name tree.AliasClause cols physical.Presentation originalExpr tree.Statement + bindingProps *props.Relational expr memo.RelExpr - id opt.WithID + // If set, this function is called when a CTE is referenced. It can throw an + // error. + onRef func() } // groupByStrSet is a set of stringified GROUP BY expressions that map to the @@ -230,6 +235,14 @@ func (s *scope) getColumn(col opt.ColumnID) *scopeColumn { return nil } +func (s *scope) makeColumnTypes() []*types.T { + res := make([]*types.T, len(s.cols)) + for i := range res { + res[i] = s.cols[i].typ + } + return res +} + // makeOrderingChoice returns an OrderingChoice that corresponds to s.ordering. func (s *scope) makeOrderingChoice() physical.OrderingChoice { var oc physical.OrderingChoice @@ -292,6 +305,9 @@ func (s *scope) resolveCTE(name *tree.TableName) *cteSource { seenCTEs = true } if cte, ok := s.ctes[nameStr]; ok { + if cte.onRef != nil { + cte.onRef() + } return cte } } diff --git a/pkg/sql/opt/optbuilder/select.go b/pkg/sql/opt/optbuilder/select.go index 60f8801a6cd1..76406bcdb31f 100644 --- a/pkg/sql/opt/optbuilder/select.go +++ b/pkg/sql/opt/optbuilder/select.go @@ -91,7 +91,7 @@ func (b *Builder) buildDataSource( Name: string(cte.name.Alias), InCols: inCols, OutCols: outCols, - BindingProps: cte.expr.Relational(), + BindingProps: cte.bindingProps, }) return outScope } @@ -571,54 +571,44 @@ func (b *Builder) buildWithOrdinality(colName string, inScope *scope) (outScope return inScope } -func (b *Builder) buildCTE( - ctes []*tree.CTE, inScope *scope, +func (b *Builder) buildCTEs( + with *tree.With, inScope *scope, ) (outScope *scope, addedCTEs []cteSource) { outScope = inScope.push() - start := len(b.ctes) - + addedCTEs = make([]cteSource, len(with.CTEList)) outScope.ctes = make(map[string]*cteSource) - for i := range ctes { - cteScope := b.buildStmt(ctes[i].Stmt, nil /* desiredTypes */, outScope) - cols := b.getCTECols(cteScope, ctes[i].Name) - name := ctes[i].Name.Alias + for i, cte := range with.CTEList { + cteExpr, cteCols := b.buildCTE(cte, outScope, with.Recursive) // TODO(justin): lift this restriction when possible. WITH should be hoistable. if b.subquery != nil && !b.subquery.outerCols.Empty() { panic(pgerror.Newf(pgcode.FeatureNotSupported, "CTEs may not be correlated")) } - if _, ok := outScope.ctes[name.String()]; ok { + aliasStr := cte.Name.Alias.String() + if _, ok := outScope.ctes[aliasStr]; ok { panic(pgerror.Newf( - pgcode.DuplicateAlias, - "WITH query name %s specified more than once", ctes[i].Name.Alias), - ) + pgcode.DuplicateAlias, "WITH query name %s specified more than once", aliasStr, + )) } - cteScope.removeHiddenCols() - projectionsScope := cteScope.replace() - projectionsScope.appendColumnsFromScope(cteScope) - b.constructProjectForScope(cteScope, projectionsScope) - - cteScope = projectionsScope - id := b.factory.Memo().NextWithID() - b.ctes = append(b.ctes, cteSource{ - name: ctes[i].Name, - cols: cols, - originalExpr: ctes[i].Stmt, - expr: cteScope.expr, + addedCTEs[i] = cteSource{ + name: cte.Name, + cols: cteCols, + originalExpr: cte.Stmt, + expr: cteExpr, + bindingProps: cteExpr.Relational(), id: id, - }) - cte := &b.ctes[len(b.ctes)-1] - outScope.ctes[ctes[i].Name.Alias.String()] = cte + } + outScope.ctes[aliasStr] = &addedCTEs[i] } telemetry.Inc(sqltelemetry.CteUseCounter) - return outScope, b.ctes[start:] + return outScope, addedCTEs } // wrapWithCTEs adds With expressions on top of an expression. @@ -656,7 +646,7 @@ func (b *Builder) buildSelectStmt( return b.buildSelectClause(stmt, nil /* orderBy */, desiredTypes, inScope) case *tree.UnionClause: - return b.buildUnion(stmt, desiredTypes, inScope) + return b.buildUnionClause(stmt, desiredTypes, inScope) case *tree.ValuesClause: return b.buildValuesClause(stmt, desiredTypes, inScope) @@ -725,7 +715,7 @@ func (b *Builder) buildSelect( var ctes []cteSource if with != nil { - inScope, ctes = b.buildCTE(with.CTEList, inScope) + inScope, ctes = b.buildCTEs(with, inScope) } // NB: The case statements are sorted lexicographically. @@ -734,7 +724,7 @@ func (b *Builder) buildSelect( outScope = b.buildSelectClause(t, orderBy, desiredTypes, inScope) case *tree.UnionClause: - outScope = b.buildUnion(t, desiredTypes, inScope) + outScope = b.buildUnionClause(t, desiredTypes, inScope) case *tree.ValuesClause: outScope = b.buildValuesClause(t, desiredTypes, inScope) diff --git a/pkg/sql/opt/optbuilder/testdata/with b/pkg/sql/opt/optbuilder/testdata/with index add3b69bfe9e..d3cc5b2e3cb4 100644 --- a/pkg/sql/opt/optbuilder/testdata/with +++ b/pkg/sql/opt/optbuilder/testdata/with @@ -511,44 +511,36 @@ error (42P01): no data source matches prefix: x build WITH t(x) AS (WITH t(x) AS (SELECT 1) SELECT x * 10 FROM t) SELECT x + 2 FROM t ---- -with &1 (t) +with &2 (t) ├── columns: "?column?":5(int) - ├── project - │ ├── columns: "?column?":1(int!null) - │ ├── values - │ │ └── tuple [type=tuple] - │ └── projections - │ └── const: 1 [type=int] - └── with &2 (t) + ├── with &1 (t) + │ ├── columns: "?column?":3(int) + │ ├── project + │ │ ├── columns: "?column?":1(int!null) + │ │ ├── values + │ │ │ └── tuple [type=tuple] + │ │ └── projections + │ │ └── const: 1 [type=int] + │ └── project + │ ├── columns: "?column?":3(int) + │ ├── with-scan &1 (t) + │ │ ├── columns: x:2(int!null) + │ │ └── mapping: + │ │ └── "?column?":1(int) => x:2(int) + │ └── projections + │ └── mult [type=int] + │ ├── variable: x [type=int] + │ └── const: 10 [type=int] + └── project ├── columns: "?column?":5(int) - ├── with &1 (t) - │ ├── columns: "?column?":3(int) - │ ├── project - │ │ ├── columns: "?column?":1(int!null) - │ │ ├── values - │ │ │ └── tuple [type=tuple] - │ │ └── projections - │ │ └── const: 1 [type=int] - │ └── project - │ ├── columns: "?column?":3(int) - │ ├── with-scan &1 (t) - │ │ ├── columns: x:2(int!null) - │ │ └── mapping: - │ │ └── "?column?":1(int) => x:2(int) - │ └── projections - │ └── mult [type=int] - │ ├── variable: x [type=int] - │ └── const: 10 [type=int] - └── project - ├── columns: "?column?":5(int) - ├── with-scan &2 (t) - │ ├── columns: x:4(int) - │ └── mapping: - │ └── "?column?":3(int) => x:4(int) - └── projections - └── plus [type=int] - ├── variable: x [type=int] - └── const: 2 [type=int] + ├── with-scan &2 (t) + │ ├── columns: x:4(int) + │ └── mapping: + │ └── "?column?":3(int) => x:4(int) + └── projections + └── plus [type=int] + ├── variable: x [type=int] + └── const: 2 [type=int] build WITH one AS (SELECT a AS u FROM x), @@ -788,3 +780,438 @@ with &1 (cte) ├── description:6(string) => description:11(string) ├── columns:7(string) => columns:12(string) └── ordering:8(string) => ordering:13(string) + +# WITH RECURSIVE examples from postgres docs. + +build +WITH RECURSIVE t(n) AS ( + VALUES (1) + UNION ALL + SELECT n+1 FROM t WHERE n < 100 +) +SELECT sum(n) FROM t +---- +with &2 (t) + ├── columns: sum:6(decimal) + ├── recursive-c-t-e + │ ├── columns: n:2(int) + │ ├── working table binding: &1 + │ ├── initial columns: column1:1(int) + │ ├── recursive columns: "?column?":4(int) + │ ├── values + │ │ ├── columns: column1:1(int!null) + │ │ └── tuple [type=tuple{int}] + │ │ └── const: 1 [type=int] + │ └── project + │ ├── columns: "?column?":4(int) + │ ├── select + │ │ ├── columns: n:3(int!null) + │ │ ├── with-scan &1 (t) + │ │ │ ├── columns: n:3(int) + │ │ │ └── mapping: + │ │ │ └── n:2(int) => n:3(int) + │ │ └── filters + │ │ └── lt [type=bool] + │ │ ├── variable: n [type=int] + │ │ └── const: 100 [type=int] + │ └── projections + │ └── plus [type=int] + │ ├── variable: n [type=int] + │ └── const: 1 [type=int] + └── scalar-group-by + ├── columns: sum:6(decimal) + ├── with-scan &2 (t) + │ ├── columns: n:5(int) + │ └── mapping: + │ └── n:2(int) => n:5(int) + └── aggregations + └── sum [type=decimal] + └── variable: n [type=int] + +exec-ddl +CREATE TABLE parts (part STRING, sub_part STRING, quantity INT) +---- + +build +WITH RECURSIVE included_parts(sub_part, part, quantity) AS ( + SELECT sub_part, part, quantity FROM parts WHERE part = 'our_product' + UNION ALL + SELECT p.sub_part, p.part, p.quantity + FROM included_parts AS pr, parts AS p + WHERE p.part = pr.sub_part +) +SELECT sub_part, sum(quantity) as total_quantity +FROM included_parts +GROUP BY sub_part +---- +with &2 (included_parts) + ├── columns: sub_part:15(string) total_quantity:18(decimal) + ├── recursive-c-t-e + │ ├── columns: sub_part:5(string) part:6(string) quantity:7(int) + │ ├── working table binding: &1 + │ ├── initial columns: parts.sub_part:2(string) parts.part:1(string) parts.quantity:3(int) + │ ├── recursive columns: p.sub_part:12(string) p.part:11(string) p.quantity:13(int) + │ ├── project + │ │ ├── columns: parts.part:1(string!null) parts.sub_part:2(string) parts.quantity:3(int) + │ │ └── select + │ │ ├── columns: parts.part:1(string!null) parts.sub_part:2(string) parts.quantity:3(int) parts.rowid:4(int!null) + │ │ ├── scan parts + │ │ │ └── columns: parts.part:1(string) parts.sub_part:2(string) parts.quantity:3(int) parts.rowid:4(int!null) + │ │ └── filters + │ │ └── eq [type=bool] + │ │ ├── variable: parts.part [type=string] + │ │ └── const: 'our_product' [type=string] + │ └── project + │ ├── columns: p.part:11(string!null) p.sub_part:12(string) p.quantity:13(int) + │ └── select + │ ├── columns: sub_part:8(string!null) part:9(string) quantity:10(int) p.part:11(string!null) p.sub_part:12(string) p.quantity:13(int) p.rowid:14(int!null) + │ ├── inner-join (hash) + │ │ ├── columns: sub_part:8(string) part:9(string) quantity:10(int) p.part:11(string) p.sub_part:12(string) p.quantity:13(int) p.rowid:14(int!null) + │ │ ├── with-scan &1 (included_parts) + │ │ │ ├── columns: sub_part:8(string) part:9(string) quantity:10(int) + │ │ │ └── mapping: + │ │ │ ├── sub_part:5(string) => sub_part:8(string) + │ │ │ ├── part:6(string) => part:9(string) + │ │ │ └── quantity:7(int) => quantity:10(int) + │ │ ├── scan p + │ │ │ └── columns: p.part:11(string) p.sub_part:12(string) p.quantity:13(int) p.rowid:14(int!null) + │ │ └── filters (true) + │ └── filters + │ └── eq [type=bool] + │ ├── variable: p.part [type=string] + │ └── variable: sub_part [type=string] + └── group-by + ├── columns: sub_part:15(string) sum:18(decimal) + ├── grouping columns: sub_part:15(string) + ├── project + │ ├── columns: sub_part:15(string) quantity:17(int) + │ └── with-scan &2 (included_parts) + │ ├── columns: sub_part:15(string) part:16(string) quantity:17(int) + │ └── mapping: + │ ├── sub_part:5(string) => sub_part:15(string) + │ ├── part:6(string) => part:16(string) + │ └── quantity:7(int) => quantity:17(int) + └── aggregations + └── sum [type=decimal] + └── variable: quantity [type=int] + + +exec-ddl +CREATE TABLE graph (id INT PRIMARY KEY, link INT, data STRING) +---- + +build +WITH RECURSIVE search_graph(id, link, data, depth, path, cycle) AS ( + SELECT g.id, g.link, g.data, 1, + ARRAY[g.id], + false + FROM graph g + UNION ALL + SELECT g.id, g.link, g.data, sg.depth + 1, + path || g.id, + g.id = ANY(path) + FROM graph g, search_graph sg + WHERE g.id = sg.link AND NOT cycle +) +SELECT * FROM search_graph +---- +with &2 (search_graph) + ├── columns: id:25(int) link:26(int) data:27(string) depth:28(int) path:29(int[]) cycle:30(bool) + ├── recursive-c-t-e + │ ├── columns: id:7(int) link:8(int) data:9(string) depth:10(int) path:11(int[]) cycle:12(bool) + │ ├── working table binding: &1 + │ ├── initial columns: g.id:1(int) g.link:2(int) g.data:3(string) "?column?":4(int) array:5(int[]) bool:6(bool) + │ ├── recursive columns: g.id:13(int) g.link:14(int) g.data:15(string) "?column?":22(int) "?column?":23(int[]) "?column?":24(bool) + │ ├── project + │ │ ├── columns: "?column?":4(int!null) array:5(int[]) bool:6(bool!null) g.id:1(int!null) g.link:2(int) g.data:3(string) + │ │ ├── scan g + │ │ │ └── columns: g.id:1(int!null) g.link:2(int) g.data:3(string) + │ │ └── projections + │ │ ├── const: 1 [type=int] + │ │ ├── array: [type=int[]] + │ │ │ └── variable: g.id [type=int] + │ │ └── false [type=bool] + │ └── project + │ ├── columns: "?column?":22(int) "?column?":23(int[]) "?column?":24(bool) g.id:13(int!null) g.link:14(int) g.data:15(string) + │ ├── select + │ │ ├── columns: g.id:13(int!null) g.link:14(int) g.data:15(string) id:16(int) link:17(int!null) data:18(string) depth:19(int) path:20(int[]) cycle:21(bool!null) + │ │ ├── inner-join (hash) + │ │ │ ├── columns: g.id:13(int!null) g.link:14(int) g.data:15(string) id:16(int) link:17(int) data:18(string) depth:19(int) path:20(int[]) cycle:21(bool) + │ │ │ ├── scan g + │ │ │ │ └── columns: g.id:13(int!null) g.link:14(int) g.data:15(string) + │ │ │ ├── with-scan &1 (search_graph) + │ │ │ │ ├── columns: id:16(int) link:17(int) data:18(string) depth:19(int) path:20(int[]) cycle:21(bool) + │ │ │ │ └── mapping: + │ │ │ │ ├── id:7(int) => id:16(int) + │ │ │ │ ├── link:8(int) => link:17(int) + │ │ │ │ ├── data:9(string) => data:18(string) + │ │ │ │ ├── depth:10(int) => depth:19(int) + │ │ │ │ ├── path:11(int[]) => path:20(int[]) + │ │ │ │ └── cycle:12(bool) => cycle:21(bool) + │ │ │ └── filters (true) + │ │ └── filters + │ │ └── and [type=bool] + │ │ ├── eq [type=bool] + │ │ │ ├── variable: g.id [type=int] + │ │ │ └── variable: link [type=int] + │ │ └── not [type=bool] + │ │ └── variable: cycle [type=bool] + │ └── projections + │ ├── plus [type=int] + │ │ ├── variable: depth [type=int] + │ │ └── const: 1 [type=int] + │ ├── concat [type=int[]] + │ │ ├── variable: path [type=int[]] + │ │ └── variable: g.id [type=int] + │ └── any-scalar: eq [type=bool] + │ ├── variable: g.id [type=int] + │ └── variable: path [type=int[]] + └── with-scan &2 (search_graph) + ├── columns: id:25(int) link:26(int) data:27(string) depth:28(int) path:29(int[]) cycle:30(bool) + └── mapping: + ├── id:7(int) => id:25(int) + ├── link:8(int) => link:26(int) + ├── data:9(string) => data:27(string) + ├── depth:10(int) => depth:28(int) + ├── path:11(int[]) => path:29(int[]) + └── cycle:12(bool) => cycle:30(bool) + +# Test where initial query has duplicate columns. +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 0, 0 + UNION ALL + SELECT a+1, b+10 FROM cte WHERE a < 5 +) SELECT * FROM cte; +---- +with &2 (cte) + ├── columns: a:8(int) b:9(int) + ├── recursive-c-t-e + │ ├── columns: a:2(int) b:3(int) + │ ├── working table binding: &1 + │ ├── initial columns: "?column?":1(int) "?column?":1(int) + │ ├── recursive columns: "?column?":6(int) "?column?":7(int) + │ ├── project + │ │ ├── columns: "?column?":1(int!null) + │ │ ├── values + │ │ │ └── tuple [type=tuple] + │ │ └── projections + │ │ └── const: 0 [type=int] + │ └── project + │ ├── columns: "?column?":6(int) "?column?":7(int) + │ ├── select + │ │ ├── columns: a:4(int!null) b:5(int) + │ │ ├── with-scan &1 (cte) + │ │ │ ├── columns: a:4(int) b:5(int) + │ │ │ └── mapping: + │ │ │ ├── a:2(int) => a:4(int) + │ │ │ └── b:3(int) => b:5(int) + │ │ └── filters + │ │ └── lt [type=bool] + │ │ ├── variable: a [type=int] + │ │ └── const: 5 [type=int] + │ └── projections + │ ├── plus [type=int] + │ │ ├── variable: a [type=int] + │ │ └── const: 1 [type=int] + │ └── plus [type=int] + │ ├── variable: b [type=int] + │ └── const: 10 [type=int] + └── with-scan &2 (cte) + ├── columns: a:8(int) b:9(int) + └── mapping: + ├── a:2(int) => a:8(int) + └── b:3(int) => b:9(int) + +# Test where recursive query has duplicate columns. +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 0, 1 + UNION ALL + SELECT a+1, a+1 FROM cte WHERE a < 5 +) SELECT * FROM cte; +---- +with &2 (cte) + ├── columns: a:8(int) b:9(int) + ├── recursive-c-t-e + │ ├── columns: a:3(int) b:4(int) + │ ├── working table binding: &1 + │ ├── initial columns: "?column?":1(int) "?column?":2(int) + │ ├── recursive columns: "?column?":7(int) "?column?":7(int) + │ ├── project + │ │ ├── columns: "?column?":1(int!null) "?column?":2(int!null) + │ │ ├── values + │ │ │ └── tuple [type=tuple] + │ │ └── projections + │ │ ├── const: 0 [type=int] + │ │ └── const: 1 [type=int] + │ └── project + │ ├── columns: "?column?":7(int) + │ ├── select + │ │ ├── columns: a:5(int!null) b:6(int) + │ │ ├── with-scan &1 (cte) + │ │ │ ├── columns: a:5(int) b:6(int) + │ │ │ └── mapping: + │ │ │ ├── a:3(int) => a:5(int) + │ │ │ └── b:4(int) => b:6(int) + │ │ └── filters + │ │ └── lt [type=bool] + │ │ ├── variable: a [type=int] + │ │ └── const: 5 [type=int] + │ └── projections + │ └── plus [type=int] + │ ├── variable: a [type=int] + │ └── const: 1 [type=int] + └── with-scan &2 (cte) + ├── columns: a:8(int) b:9(int) + └── mapping: + ├── a:3(int) => a:8(int) + └── b:4(int) => b:9(int) + +# Allow non-recursive CTE when RECURSIVE is used. +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1, 2 +) SELECT * FROM cte; +---- +with &2 (cte) + ├── columns: a:3(int!null) b:4(int!null) + ├── project + │ ├── columns: "?column?":1(int!null) "?column?":2(int!null) + │ ├── values + │ │ └── tuple [type=tuple] + │ └── projections + │ ├── const: 1 [type=int] + │ └── const: 2 [type=int] + └── with-scan &2 (cte) + ├── columns: a:3(int!null) b:4(int!null) + └── mapping: + ├── "?column?":1(int) => a:3(int) + └── "?column?":2(int) => b:4(int) + +# Allow non-recursive CTE even when it has UNION ALL. +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1, 2 + UNION ALL + SELECT 3, 4 +) SELECT * FROM cte; +---- +with &2 (cte) + ├── columns: a:9(int!null) b:10(int!null) + ├── union + │ ├── columns: "?column?":7(int!null) "?column?":8(int!null) + │ ├── left columns: "?column?":1(int) "?column?":2(int) + │ ├── right columns: "?column?":5(int) "?column?":6(int) + │ ├── project + │ │ ├── columns: "?column?":1(int!null) "?column?":2(int!null) + │ │ ├── values + │ │ │ └── tuple [type=tuple] + │ │ └── projections + │ │ ├── const: 1 [type=int] + │ │ └── const: 2 [type=int] + │ └── project + │ ├── columns: "?column?":5(int!null) "?column?":6(int!null) + │ ├── values + │ │ └── tuple [type=tuple] + │ └── projections + │ ├── const: 3 [type=int] + │ └── const: 4 [type=int] + └── with-scan &2 (cte) + ├── columns: a:9(int!null) b:10(int!null) + └── mapping: + ├── "?column?":7(int) => a:9(int) + └── "?column?":8(int) => b:10(int) + +# Error cases. +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1+a, 1+b FROM cte +) SELECT * FROM cte; +---- +error (42601): recursive query "cte" does not have the form non-recursive-term UNION ALL recursive-term + +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1, 2 + UNION + SELECT 1+a, 1+b FROM cte +) SELECT * FROM cte; +---- +error (42601): recursive query "cte" does not have the form non-recursive-term UNION ALL recursive-term + +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1+a, 1+b FROM cte + UNION ALL + SELECT 3, 4 +) SELECT * FROM cte; +---- +error (42601): recursive reference to query "cte" must not appear within its non-recursive term + +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1, 2 + UNION ALL + SELECT c1.a+c2.a, c1.b+c2.b FROM cte AS c1, cte AS c2 +) SELECT * FROM cte; +---- +error (42601): recursive reference to query "cte" must not appear more than once + +# If we really need to reference the working table multiple times, we can use +# an inner WITH. +build +WITH RECURSIVE cte(a, b) AS ( + SELECT 1, 2 + UNION ALL + (WITH foo AS (SELECT * FROM cte) SELECT c1.a+c2.a, c1.b+c2.b FROM foo AS c1, foo AS c2) +) SELECT * FROM cte; +---- +with &3 (cte) + ├── columns: a:13(int) b:14(int) + ├── recursive-c-t-e + │ ├── columns: a:3(int) b:4(int) + │ ├── working table binding: &1 + │ ├── initial columns: "?column?":1(int) "?column?":2(int) + │ ├── recursive columns: "?column?":11(int) "?column?":12(int) + │ ├── project + │ │ ├── columns: "?column?":1(int!null) "?column?":2(int!null) + │ │ ├── values + │ │ │ └── tuple [type=tuple] + │ │ └── projections + │ │ ├── const: 1 [type=int] + │ │ └── const: 2 [type=int] + │ └── with &2 (foo) + │ ├── columns: "?column?":11(int) "?column?":12(int) + │ ├── with-scan &1 (cte) + │ │ ├── columns: a:5(int) b:6(int) + │ │ └── mapping: + │ │ ├── a:3(int) => a:5(int) + │ │ └── b:4(int) => b:6(int) + │ └── project + │ ├── columns: "?column?":11(int) "?column?":12(int) + │ ├── inner-join (hash) + │ │ ├── columns: a:7(int) b:8(int) a:9(int) b:10(int) + │ │ ├── with-scan &2 (foo) + │ │ │ ├── columns: a:7(int) b:8(int) + │ │ │ └── mapping: + │ │ │ ├── a:5(int) => a:7(int) + │ │ │ └── b:6(int) => b:8(int) + │ │ ├── with-scan &2 (foo) + │ │ │ ├── columns: a:9(int) b:10(int) + │ │ │ └── mapping: + │ │ │ ├── a:5(int) => a:9(int) + │ │ │ └── b:6(int) => b:10(int) + │ │ └── filters (true) + │ └── projections + │ ├── plus [type=int] + │ │ ├── variable: a [type=int] + │ │ └── variable: a [type=int] + │ └── plus [type=int] + │ ├── variable: b [type=int] + │ └── variable: b [type=int] + └── with-scan &3 (cte) + ├── columns: a:13(int) b:14(int) + └── mapping: + ├── a:3(int) => a:13(int) + └── b:4(int) => b:14(int) diff --git a/pkg/sql/opt/optbuilder/union.go b/pkg/sql/opt/optbuilder/union.go index 76459f16963c..0ad5e095c95a 100644 --- a/pkg/sql/opt/optbuilder/union.go +++ b/pkg/sql/opt/optbuilder/union.go @@ -18,25 +18,27 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/types" ) -// buildUnion builds a set of memo groups that represent the given union +// buildUnionClause builds a set of memo groups that represent the given union // clause. // // See Builder.buildStmt for a description of the remaining input and // return values. -func (b *Builder) buildUnion( +func (b *Builder) buildUnionClause( clause *tree.UnionClause, desiredTypes []*types.T, inScope *scope, ) (outScope *scope) { leftScope := b.buildSelect(clause.Left, desiredTypes, inScope) // Try to propagate types left-to-right, if we didn't already have desired // types. if len(desiredTypes) == 0 { - desiredTypes = make([]*types.T, len(leftScope.cols)) - for i := range leftScope.cols { - desiredTypes[i] = leftScope.cols[i].typ - } + desiredTypes = leftScope.makeColumnTypes() } rightScope := b.buildSelect(clause.Right, desiredTypes, inScope) + return b.buildSetOp(clause.Type, clause.All, inScope, leftScope, rightScope) +} +func (b *Builder) buildSetOp( + typ tree.UnionType, all bool, inScope, leftScope, rightScope *scope, +) (outScope *scope) { // Remove any hidden columns, as they are not included in the Union. leftScope.removeHiddenCols() rightScope.removeHiddenCols() @@ -55,7 +57,7 @@ func (b *Builder) buildUnion( leftScope, rightScope, true, /* tolerateUnknownLeft */ true, /* tolerateUnknownRight */ - clause.Type.String(), + typ.String(), ) if propagateTypesLeft { @@ -69,7 +71,7 @@ func (b *Builder) buildUnion( // values from both the left and right relations). This is not necessary for // INTERSECT or EXCEPT, since these operations are basically filters on the // left relation. - if clause.Type == tree.UnionOp { + if typ == tree.UnionOp { outScope.cols = make([]scopeColumn, 0, len(leftScope.cols)) for i := range leftScope.cols { c := &leftScope.cols[i] @@ -89,8 +91,8 @@ func (b *Builder) buildUnion( right := rightScope.expr.(memo.RelExpr) private := memo.SetPrivate{LeftCols: leftCols, RightCols: rightCols, OutCols: newCols} - if clause.All { - switch clause.Type { + if all { + switch typ { case tree.UnionOp: outScope.expr = b.factory.ConstructUnionAll(left, right, &private) case tree.IntersectOp: @@ -99,7 +101,7 @@ func (b *Builder) buildUnion( outScope.expr = b.factory.ConstructExceptAll(left, right, &private) } } else { - switch clause.Type { + switch typ { case tree.UnionOp: outScope.expr = b.factory.ConstructUnion(left, right, &private) case tree.IntersectOp: diff --git a/pkg/sql/opt/optbuilder/update.go b/pkg/sql/opt/optbuilder/update.go index 07eb2b26a3aa..4796047994e2 100644 --- a/pkg/sql/opt/optbuilder/update.go +++ b/pkg/sql/opt/optbuilder/update.go @@ -76,7 +76,7 @@ func (b *Builder) buildUpdate(upd *tree.Update, inScope *scope) (outScope *scope var ctes []cteSource if upd.With != nil { - inScope, ctes = b.buildCTE(upd.With.CTEList, inScope) + inScope, ctes = b.buildCTEs(upd.With, inScope) } // UPDATE xx AS yy - we want to know about xx (tn) because diff --git a/pkg/sql/opt/optbuilder/with.go b/pkg/sql/opt/optbuilder/with.go index 529f807203ab..ae7a58a4da1e 100644 --- a/pkg/sql/opt/optbuilder/with.go +++ b/pkg/sql/opt/optbuilder/with.go @@ -11,6 +11,8 @@ package optbuilder import ( + "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" + "github.com/cockroachdb/cockroach/pkg/sql/opt/props" "github.com/cockroachdb/cockroach/pkg/sql/opt/props/physical" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" @@ -18,9 +20,160 @@ import ( "github.com/cockroachdb/errors" ) -// getCTECols renames a presentation for the scope, renaming the columns to +func (b *Builder) buildCTE( + cte *tree.CTE, inScope *scope, isRecursive bool, +) (memo.RelExpr, physical.Presentation) { + if !isRecursive { + cteScope := b.buildStmt(cte.Stmt, nil /* desiredTypes */, inScope) + cteScope.removeHiddenCols() + b.dropOrderingAndExtraCols(cteScope) + return cteScope.expr, b.getCTECols(cteScope, cte.Name) + } + + // WITH RECURSIVE queries are always of the form: + // + // WITH RECURSIVE name(cols) AS ( + // initial_query + // UNION ALL + // recursive_query + // ) + // + // Recursive CTE evaluation (paraphrased from postgres docs): + // 1. Evaluate the initial query; emit the results and also save them in + // a "working" table. + // 2. So long as the working table is not empty: + // * evaluate the recursive query, substituting the current contents of + // the working table for the recursive self-reference. + // * emit all resulting rows, and save them as the next iteration's + // working table. + // + // Note however, that a non-recursive CTE can be used even when RECURSIVE is + // specified (particularly useful when there are multiple CTEs defined). + // Handling this while having decent error messages is tricky. + + // Generate an id for the recursive CTE reference. This is the id through + // which the recursive expression refers to the current working table + // (via WithScan). + withID := b.factory.Memo().NextWithID() + + // cteScope allows recursive references to this CTE. + cteScope := inScope.push() + cteSrc := &cteSource{ + id: withID, + name: cte.Name, + } + cteScope.ctes = map[string]*cteSource{cte.Name.Alias.String(): cteSrc} + + initial, recursive, ok := b.splitRecursiveCTE(cte.Stmt) + if !ok { + // Build this as a non-recursive CTE, but throw a proper error message if it + // does have a recursive reference. + cteSrc.onRef = func() { + panic(pgerror.Newf( + pgcode.Syntax, + "recursive query %q does not have the form non-recursive-term UNION ALL recursive-term", + cte.Name.Alias, + )) + } + return b.buildCTE(cte, cteScope, false /* recursive */) + } + + // Set up an error if the initial part has a recursive reference. + cteSrc.onRef = func() { + panic(pgerror.Newf( + pgcode.Syntax, + "recursive reference to query %q must not appear within its non-recursive term", + cte.Name.Alias, + )) + } + initialScope := b.buildStmt(initial, nil /* desiredTypes */, cteScope) + + initialScope.removeHiddenCols() + b.dropOrderingAndExtraCols(initialScope) + + // The properties of the binding are tricky: the recursive expression is + // invoked repeatedly and these must hold each time. We can't use the initial + // expression's properties directly, as those only hold the first time the + // recursive query is executed. We can't really say too much about what the + // working table contains, except that it has at least one row (the recursive + // query is never invoked with an empty working table). + bindingProps := &props.Relational{} + bindingProps.OutputCols = initialScope.colSet() + bindingProps.Cardinality = props.AnyCardinality.AtLeast(props.OneCardinality) + // We don't really know the input row count, except for the first time we run + // the recursive query. We don't have anything better though. + bindingProps.Stats.RowCount = initialScope.expr.Relational().Stats.RowCount + cteSrc.bindingProps = bindingProps + + cteSrc.cols = b.getCTECols(initialScope, cte.Name) + + outScope := inScope.push() + + initialTypes := initialScope.makeColumnTypes() + + // Synthesize new output columns (because they contain values from both the + // initial and the recursive relations). These columns will also be used to + // refer to the working table (from the recursive query); we can't use the + // initial columns directly because they might contain duplicate IDs (e.g. + // consider initial query SELECT 0, 0). + for i, c := range cteSrc.cols { + newCol := b.synthesizeColumn(outScope, c.Alias, initialTypes[i], nil /* expr */, nil /* scalar */) + cteSrc.cols[i].ID = newCol.id + } + + // We want to check if the recursive query is actually recursive. This is for + // annoying cases like `SELECT 1 UNION ALL SELECT 2`. + numRefs := 0 + cteSrc.onRef = func() { + numRefs++ + } + + recursiveScope := b.buildSelect(recursive, initialTypes /* desiredTypes */, cteScope) + + if numRefs == 0 { + // Build this as a non-recursive CTE. + cteScope := b.buildSetOp(tree.UnionOp, false /* all */, inScope, initialScope, recursiveScope) + return cteScope.expr, b.getCTECols(cteScope, cte.Name) + } + + if numRefs != 1 { + // We disallow multiple recursive references for consistency with postgres. + panic(pgerror.Newf( + pgcode.Syntax, + "recursive reference to query %q must not appear more than once", + cte.Name.Alias, + )) + } + + recursiveScope.removeHiddenCols() + b.dropOrderingAndExtraCols(recursiveScope) + + // We allow propagation of types from the initial query to the recursive + // query. + _, propagateToRight := b.checkTypesMatch(initialScope, recursiveScope, + false, /* tolerateUnknownLeft */ + true, /* tolerateUnknownRight */ + "UNION", + ) + if propagateToRight { + recursiveScope = b.propagateTypes(recursiveScope /* dst */, initialScope /* src */) + } + + private := memo.RecursiveCTEPrivate{ + Name: string(cte.Name.Alias), + WithID: withID, + InitialCols: colsToColList(initialScope.cols), + RecursiveCols: colsToColList(recursiveScope.cols), + OutCols: colsToColList(outScope.cols), + } + + expr := b.factory.ConstructRecursiveCTE(initialScope.expr, recursiveScope.expr, &private) + return expr, cteSrc.cols +} + +// getCTECols returns a presentation for the scope, renaming the columns to // those provided in the AliasClause (if any). Throws an error if there is a -// mistmatch in the number of columns. +// mismatch in the number of columns. func (b *Builder) getCTECols(cteScope *scope, name tree.AliasClause) physical.Presentation { presentation := cteScope.makePresentation() @@ -49,3 +202,23 @@ func (b *Builder) getCTECols(cteScope *scope, name tree.AliasClause) physical.Pr } return presentation } + +// splitRecursiveCTE splits a CTE statement of the form +// initial_query UNION ALL recursive_query +// into the initial and recursive parts. If the statement is not of this form, +// returns ok=false. +func (b *Builder) splitRecursiveCTE( + stmt tree.Statement, +) (initial, recursive *tree.Select, ok bool) { + sel, ok := stmt.(*tree.Select) + // The form above doesn't allow for "outer" WITH, ORDER BY, or LIMIT + // clauses. + if !ok || sel.With != nil || sel.OrderBy != nil || sel.Limit != nil { + return nil, nil, false + } + union, ok := sel.Select.(*tree.UnionClause) + if !ok || union.Type != tree.UnionOp || !union.All { + return nil, nil, false + } + return union.Left, union.Right, true +} diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index d7a50cc4d4d7..6713e4e21e6a 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -808,6 +808,17 @@ func (ef *execFactory) ConstructScanBuffer(ref exec.Node, label string) (exec.No }, nil } +// ConstructRecursiveCTE is part of the exec.Factory interface. +func (ef *execFactory) ConstructRecursiveCTE( + initial exec.Node, fn exec.RecursiveCTEIterationFn, label string, +) (exec.Node, error) { + return &recursiveCTENode{ + initial: initial.(planNode), + genIterationFn: fn, + label: label, + }, nil +} + // ConstructProjectSet is part of the exec.Factory interface. func (ef *execFactory) ConstructProjectSet( n exec.Node, exprs tree.TypedExprs, zipCols sqlbase.ResultColumns, numColsPerGen []int, diff --git a/pkg/sql/opt_limits.go b/pkg/sql/opt_limits.go index 4e6a098179ab..275a9368fcc7 100644 --- a/pkg/sql/opt_limits.go +++ b/pkg/sql/opt_limits.go @@ -245,6 +245,7 @@ func (p *planner) applyLimit(plan planNode, numRows int64, soft bool) { case *showTraceNode: case *scatterNode: case *scanBufferNode: + case *recursiveCTENode: case *unsplitAllNode: case *applyJoinNode, *lookupJoinNode, *zigzagJoinNode, *saveTableNode: diff --git a/pkg/sql/parser/parse_test.go b/pkg/sql/parser/parse_test.go index 6fdc19506bc5..3515bab5f32c 100644 --- a/pkg/sql/parser/parse_test.go +++ b/pkg/sql/parser/parse_test.go @@ -1355,6 +1355,10 @@ func TestParse(t *testing.T) { // Regression for #15926 {`SELECT * FROM ((t1 NATURAL JOIN t2 WITH ORDINALITY AS o1)) WITH ORDINALITY AS o2`}, + + {`WITH cte AS (SELECT 1) SELECT * FROM cte`}, + {`WITH cte (x) AS (INSERT INTO abc VALUES (1, 2)), cte2 (y) AS (SELECT x + 1 FROM cte) SELECT * FROM cte, cte2`}, + {`WITH RECURSIVE cte (x) AS (SELECT 1), cte2 (y) AS (SELECT x + 1 FROM cte) SELECT 1`}, } var p parser.Parser // Verify that the same parser can be reused. for _, d := range testData { @@ -3121,8 +3125,6 @@ func TestUnimplementedSyntax(t *testing.T) { {`INSERT INTO a VALUES (1) ON CONFLICT (x) WHERE x > 3 DO NOTHING`, 32557, ``}, - {`WITH RECURSIVE a AS (TABLE b) SELECT c`, 21085, ``}, - {`UPDATE foo SET (a, a.b) = (1, 2)`, 27792, ``}, {`UPDATE foo SET a.b = 1`, 27792, ``}, {`UPDATE Foo SET x.y = z`, 27792, ``}, diff --git a/pkg/sql/parser/sql.y b/pkg/sql/parser/sql.y index e716d15fb4cb..ac3c243e9548 100644 --- a/pkg/sql/parser/sql.y +++ b/pkg/sql/parser/sql.y @@ -6036,7 +6036,10 @@ with_clause: /* SKIP DOC */ $$.val = &tree.With{CTEList: $2.ctes()} } -| WITH RECURSIVE cte_list { return unimplementedWithIssue(sqllex, 21085) } +| WITH RECURSIVE cte_list + { + $$.val = &tree.With{Recursive: true, CTEList: $3.ctes()} + } cte_list: common_table_expr diff --git a/pkg/sql/plan.go b/pkg/sql/plan.go index 3a55946be6d8..dca3c7faeee7 100644 --- a/pkg/sql/plan.go +++ b/pkg/sql/plan.go @@ -166,6 +166,7 @@ var _ planNode = &limitNode{} var _ planNode = &max1RowNode{} var _ planNode = &ordinalityNode{} var _ planNode = &projectSetNode{} +var _ planNode = &recursiveCTENode{} var _ planNode = &relocateNode{} var _ planNode = &renameColumnNode{} var _ planNode = &renameDatabaseNode{} diff --git a/pkg/sql/plan_columns.go b/pkg/sql/plan_columns.go index b58ac15eda7e..4218b185781f 100644 --- a/pkg/sql/plan_columns.go +++ b/pkg/sql/plan_columns.go @@ -141,6 +141,8 @@ func getPlanColumns(plan planNode, mut bool) sqlbase.ResultColumns { return getPlanColumns(n.source, mut) case *scanBufferNode: return getPlanColumns(n.buffer, mut) + case *recursiveCTENode: + return getPlanColumns(n.initial, mut) case *rowSourceToPlanNode: return n.planCols diff --git a/pkg/sql/plan_physical_props.go b/pkg/sql/plan_physical_props.go index 5628ec1ea911..6de717819cc6 100644 --- a/pkg/sql/plan_physical_props.go +++ b/pkg/sql/plan_physical_props.go @@ -84,11 +84,12 @@ func planPhysicalProps(plan planNode) physicalProps { return n.props case *zigzagJoinNode: return n.props + + // Every other node simply has no guarantees on its output rows. case *applyJoinNode: case *bufferNode: case *scanBufferNode: - - // Every other node simply has no guarantees on its output rows. + case *recursiveCTENode: case *CreateUserNode: case *DropUserNode: case *alterIndexNode: diff --git a/pkg/sql/recursive_cte.go b/pkg/sql/recursive_cte.go new file mode 100644 index 000000000000..52bf37620bec --- /dev/null +++ b/pkg/sql/recursive_cte.go @@ -0,0 +1,133 @@ +// Copyright 2019 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package sql + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/sql/opt/exec" + "github.com/cockroachdb/cockroach/pkg/sql/rowcontainer" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" +) + +// recursiveCTENode implements the logic for a recursive CTE: +// 1. Evaluate the initial query; emit the results and also save them in +// a "working" table. +// 2. So long as the working table is not empty: +// * evaluate the recursive query, substituting the current contents of +// the working table for the recursive self-reference; +// * emit all resulting rows, and save them as the next iteration's +// working table. +// The recursive query tree is regenerated each time using a callback +// (implemented by the execbuilder). +type recursiveCTENode struct { + initial planNode + + genIterationFn exec.RecursiveCTEIterationFn + + label string + + recursiveCTERun +} + +type recursiveCTERun struct { + // workingRows contains the rows produced by the current iteration (aka the + // "working" table). + workingRows *rowcontainer.RowContainer + // nextRowIdx is the index inside workingRows of the next row to be returned + // by the operator. + nextRowIdx int + + initialDone bool + done bool +} + +func (n *recursiveCTENode) startExec(params runParams) error { + n.workingRows = rowcontainer.NewRowContainer( + params.EvalContext().Mon.MakeBoundAccount(), + sqlbase.ColTypeInfoFromResCols(getPlanColumns(n.initial, false /* mut */)), + 0, /* rowCapacity */ + ) + n.nextRowIdx = 0 + return nil +} + +func (n *recursiveCTENode) Next(params runParams) (bool, error) { + if err := params.p.cancelChecker.Check(); err != nil { + return false, err + } + + n.nextRowIdx++ + + if !n.initialDone { + ok, err := n.initial.Next(params) + if err != nil { + return false, err + } + if ok { + if _, err = n.workingRows.AddRow(params.ctx, n.initial.Values()); err != nil { + return false, err + } + return true, nil + } + n.initialDone = true + } + + if n.workingRows.Len() == 0 { + // Last iteration returned no rows. + return false, nil + } + + // There are more rows to return from the last iteration. + if n.nextRowIdx <= n.workingRows.Len() { + return true, nil + } + + // Let's run another iteration. + + lastWorkingRows := n.workingRows + defer lastWorkingRows.Close(params.ctx) + + n.workingRows = rowcontainer.NewRowContainer( + params.EvalContext().Mon.MakeBoundAccount(), + sqlbase.ColTypeInfoFromResCols(getPlanColumns(n.initial, false /* mut */)), + 0, /* rowCapacity */ + ) + + // Set up a bufferNode that can be used as a reference for a scanBufferNode. + buf := &bufferNode{ + // The plan here is only useful for planColumns, so it's ok to always use + // the initial plan. + plan: n.initial, + bufferedRows: lastWorkingRows, + label: n.label, + } + newPlan, err := n.genIterationFn(buf) + if err != nil { + return false, err + } + + if err := runPlanInsidePlan(params, newPlan.(*planTop), n.workingRows); err != nil { + return false, err + } + n.nextRowIdx = 1 + return n.workingRows.Len() > 0, nil +} + +func (n *recursiveCTENode) Values() tree.Datums { + return n.workingRows.At(n.nextRowIdx - 1) +} + +func (n *recursiveCTENode) Close(ctx context.Context) { + n.initial.Close(ctx) + n.workingRows.Close(ctx) +} diff --git a/pkg/sql/rowcontainer/datum_row_container.go b/pkg/sql/rowcontainer/datum_row_container.go index 9944c36f42da..b764573a8a1b 100644 --- a/pkg/sql/rowcontainer/datum_row_container.go +++ b/pkg/sql/rowcontainer/datum_row_container.go @@ -167,6 +167,10 @@ func (c *RowContainer) UnsafeReset(ctx context.Context) error { // Close releases the memory associated with the RowContainer. func (c *RowContainer) Close(ctx context.Context) { + if c == nil { + // Allow Close on an uninitialized container. + return + } c.chunks = nil c.varSizedColumns = nil c.memAcc.Close(ctx) diff --git a/pkg/sql/sem/tree/pretty.go b/pkg/sql/sem/tree/pretty.go index a45da2a6eb63..aec77b74bf02 100644 --- a/pkg/sql/sem/tree/pretty.go +++ b/pkg/sql/sem/tree/pretty.go @@ -659,7 +659,11 @@ func (node *With) docRow(p *PrettyCfg) pretty.TableRow { p.bracketKeyword("AS", " (", p.Doc(cte.Stmt), ")", ""), ) } - return p.row("WITH", p.commaSeparated(d...)) + kw := "WITH" + if node.Recursive { + kw = "WITH RECURSIVE" + } + return p.row(kw, p.commaSeparated(d...)) } func (node *Subquery) doc(p *PrettyCfg) pretty.Doc { diff --git a/pkg/sql/sem/tree/with.go b/pkg/sql/sem/tree/with.go index b71ce5bcd404..da07a777b2e6 100644 --- a/pkg/sql/sem/tree/with.go +++ b/pkg/sql/sem/tree/with.go @@ -12,7 +12,8 @@ package tree // With represents a WITH statement. type With struct { - CTEList []*CTE + Recursive bool + CTEList []*CTE } // CTE represents a common table expression inside of a WITH clause. @@ -27,6 +28,9 @@ func (node *With) Format(ctx *FmtCtx) { return } ctx.WriteString("WITH ") + if node.Recursive { + ctx.WriteString("RECURSIVE ") + } for i, cte := range node.CTEList { if i != 0 { ctx.WriteString(", ") @@ -34,6 +38,7 @@ func (node *With) Format(ctx *FmtCtx) { ctx.FormatNode(&cte.Name) ctx.WriteString(" AS (") ctx.FormatNode(cte.Stmt) - ctx.WriteString(") ") + ctx.WriteString(")") } + ctx.WriteByte(' ') } diff --git a/pkg/sql/walk.go b/pkg/sql/walk.go index 429abaaf7e87..70c258fe61a1 100644 --- a/pkg/sql/walk.go +++ b/pkg/sql/walk.go @@ -659,6 +659,12 @@ func (v *planVisitor) visitInternal(plan planNode, name string) { } n.plan = v.visit(n.plan) + case *recursiveCTENode: + if v.observer.attr != nil { + v.observer.attr(name, "label", n.label) + } + n.initial = v.visit(n.initial) + case *exportNode: if v.observer.attr != nil { v.observer.attr(name, "destination", n.fileName) @@ -796,6 +802,7 @@ var planNodeNames = map[reflect.Type]string{ reflect.TypeOf(&max1RowNode{}): "max1row", reflect.TypeOf(&ordinalityNode{}): "ordinality", reflect.TypeOf(&projectSetNode{}): "project set", + reflect.TypeOf(&recursiveCTENode{}): "recursive cte node", reflect.TypeOf(&relocateNode{}): "relocate", reflect.TypeOf(&renameColumnNode{}): "rename column", reflect.TypeOf(&renameDatabaseNode{}): "rename database",