From c4a93b6cf408391773e2c42a6b5a1f67c047078c Mon Sep 17 00:00:00 2001 From: Ridwan Sharif Date: Sun, 30 Jun 2019 13:57:10 -0400 Subject: [PATCH] opt: fetch minimal set of columns on returning mutations Previously, we used to fetch all columns when a mutation contained a `RETURNING` clause. This is an issue because it forces us to retrieve unnecessary data and creates extra contention. This change adds logic to compute the minimal set of required columns and fetches only those. This change passes down the minimal required return cols to SQL so it knows to only return columns that's requested. This fixes the output column calculation done by opt and makes sure the execPlan's output columns are the same as output cols of the opt plan. Fixes #30618. Unblocks #30624. Release note: None --- pkg/sql/delete.go | 22 +- pkg/sql/insert.go | 43 +- pkg/sql/opt/bench/stub_factory.go | 11 +- pkg/sql/opt/exec/execbuilder/mutation.go | 15 +- pkg/sql/opt/exec/execbuilder/testdata/ddl | 30 +- pkg/sql/opt/exec/execbuilder/testdata/delete | 19 +- pkg/sql/opt/exec/execbuilder/testdata/insert | 28 +- pkg/sql/opt/exec/execbuilder/testdata/orderby | 26 +- pkg/sql/opt/exec/execbuilder/testdata/update | 4 +- pkg/sql/opt/exec/execbuilder/testdata/upsert | 4 +- pkg/sql/opt/exec/factory.go | 8 +- pkg/sql/opt/memo/expr.go | 6 + pkg/sql/opt/memo/logical_props_builder.go | 6 +- pkg/sql/opt/norm/custom_funcs.go | 21 + pkg/sql/opt/norm/prune_cols.go | 40 ++ pkg/sql/opt/norm/rules/prune_cols.opt | 34 ++ pkg/sql/opt/norm/testdata/rules/prune_cols | 378 ++++++++++++++++-- pkg/sql/opt/xform/testdata/rules/join | 9 +- pkg/sql/opt_exec_factory.go | 226 +++++++++-- pkg/sql/rowcontainer/datum_row_container.go | 5 + pkg/sql/tablewriter_upsert_opt.go | 67 +++- pkg/sql/update.go | 29 +- 22 files changed, 848 insertions(+), 183 deletions(-) diff --git a/pkg/sql/delete.go b/pkg/sql/delete.go index 1572e80e76a6..288c01c36640 100644 --- a/pkg/sql/delete.go +++ b/pkg/sql/delete.go @@ -125,6 +125,9 @@ func (p *planner) Delete( requestedCols = desc.Columns } + // Since all columns are being returned, use the 1:1 mapping. See todo above. + rowIdxToRetIdx := mutationRowIdxToReturnIdx(requestedCols, requestedCols) + // Create the table deleter, which does the bulk of the work. rd, err := row.MakeDeleter( p.txn, desc, fkTables, requestedCols, row.CheckFKs, p.EvalContext(), &p.alloc, @@ -175,6 +178,7 @@ func (p *planner) Delete( td: tableDeleter{rd: rd, alloc: &p.alloc}, rowsNeeded: rowsNeeded, fastPathInterleaved: canDeleteFastInterleaved(desc, fkTables), + rowIdxToRetIdx: rowIdxToRetIdx, }, } @@ -208,6 +212,13 @@ type deleteRun struct { // traceKV caches the current KV tracing flag. traceKV bool + + // rowIdxToRetIdx is the mapping from the columns returned by the deleter + // to the columns in the resultRowBuffer. A value of -1 is used to indicate + // that the column at that index is not part of the resultRowBuffer + // of the mutation. Otherwise, the value an the i-th index refers to the + // index of the resultRowBuffer where the i-th column is to be returned. + rowIdxToRetIdx []int } // maxDeleteBatchSize is the max number of entries in the KV batch for @@ -331,9 +342,16 @@ func (d *deleteNode) processSourceRow(params runParams, sourceVals tree.Datums) // contain additional columns for every newly dropped column not // visible. We do not want them to be available for RETURNING. // - // d.columns is guaranteed to only contain the requested + // d.run.rows.NumCols() is guaranteed to only contain the requested // public columns. 
- resultValues := sourceVals[:len(d.columns)] + resultValues := make(tree.Datums, d.run.rows.NumCols()) + for i := range d.run.rowIdxToRetIdx { + retIdx := d.run.rowIdxToRetIdx[i] + if retIdx >= 0 { + resultValues[retIdx] = sourceVals[i] + } + } + if _, err := d.run.rows.AddRow(params.ctx, resultValues); err != nil { return err } diff --git a/pkg/sql/insert.go b/pkg/sql/insert.go index f72dcbb99234..7dc6e9ef40df 100644 --- a/pkg/sql/insert.go +++ b/pkg/sql/insert.go @@ -286,6 +286,12 @@ func (p *planner) Insert( columns = sqlbase.ResultColumnsFromColDescs(desc.Columns) } + // Since all columns are being returned, use the 1:1 mapping. + tabIdxToRetIdx := make([]int, len(desc.Columns)) + for i := range tabIdxToRetIdx { + tabIdxToRetIdx[i] = i + } + // At this point, everything is ready for either an insertNode or an upserNode. var node batchedPlanNode @@ -315,8 +321,9 @@ func (p *planner) Insert( Cols: desc.Columns, Mapping: ri.InsertColIDtoRowIndex, }, - defaultExprs: defaultExprs, - insertCols: ri.InsertCols, + defaultExprs: defaultExprs, + insertCols: ri.InsertCols, + tabIdxToRetIdx: tabIdxToRetIdx, }, } node = in @@ -368,12 +375,21 @@ type insertRun struct { // into the row container above, when rowsNeeded is set. resultRowBuffer tree.Datums - // rowIdxToRetIdx is the mapping from the ordering of rows in - // insertCols to the ordering in the result rows, used when + // rowIdxToTabColIdx is the mapping from the ordering of rows in + // insertCols to the ordering in the rows in the table, used when // rowsNeeded is set to populate resultRowBuffer and the row // container. The return index is -1 if the column for the row - // index is not public. - rowIdxToRetIdx []int + // index is not public. This is used in conjunction with tabIdxToRetIdx + // to populate the resultRowBuffer. + rowIdxToTabColIdx []int + + // tabIdxToRetIdx is the mapping from the columns in the table to the + // columns in the resultRowBuffer. A value of -1 is used to indicate + // that the table column at that index is not part of the resultRowBuffer + // of the mutation. Otherwise, the value an the i-th index refers to the + // index of the resultRowBuffer where the i-th column of the table is + // to be returned. + tabIdxToRetIdx []int // traceKV caches the current KV tracing flag. traceKV bool @@ -406,7 +422,7 @@ func (n *insertNode) startExec(params runParams) error { // // Also we need to re-order the values in the source, ordered by // insertCols, when writing them to resultRowBuffer, ordered by - // n.columns. This uses the rowIdxToRetIdx mapping. + // n.columns. This uses the rowIdxToTabColIdx mapping. n.run.resultRowBuffer = make(tree.Datums, len(n.columns)) for i := range n.run.resultRowBuffer { @@ -419,13 +435,13 @@ func (n *insertNode) startExec(params runParams) error { colIDToRetIndex[cols[i].ID] = i } - n.run.rowIdxToRetIdx = make([]int, len(n.run.insertCols)) + n.run.rowIdxToTabColIdx = make([]int, len(n.run.insertCols)) for i, col := range n.run.insertCols { if idx, ok := colIDToRetIndex[col.ID]; !ok { // Column must be write only and not public. - n.run.rowIdxToRetIdx[i] = -1 + n.run.rowIdxToTabColIdx[i] = -1 } else { - n.run.rowIdxToRetIdx[i] = idx + n.run.rowIdxToTabColIdx[i] = idx } } } @@ -567,10 +583,13 @@ func (n *insertNode) processSourceRow(params runParams, sourceVals tree.Datums) // The downstream consumer will want the rows in the order of // the table descriptor, not that of insertCols. Reorder them // and ignore non-public columns. 
- if idx := n.run.rowIdxToRetIdx[i]; idx >= 0 { - n.run.resultRowBuffer[idx] = val + if tabIdx := n.run.rowIdxToTabColIdx[i]; tabIdx >= 0 { + if retIdx := n.run.tabIdxToRetIdx[tabIdx]; retIdx >= 0 { + n.run.resultRowBuffer[retIdx] = val + } } } + if _, err := n.run.rows.AddRow(params.ctx, n.run.resultRowBuffer); err != nil { return err } diff --git a/pkg/sql/opt/bench/stub_factory.go b/pkg/sql/opt/bench/stub_factory.go index 7eb7b9ad500f..10aad2b9fb90 100644 --- a/pkg/sql/opt/bench/stub_factory.go +++ b/pkg/sql/opt/bench/stub_factory.go @@ -222,8 +222,8 @@ func (f *stubFactory) ConstructInsert( input exec.Node, table cat.Table, insertCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, skipFKChecks bool, ) (exec.Node, error) { return struct{}{}, nil @@ -234,8 +234,8 @@ func (f *stubFactory) ConstructUpdate( table cat.Table, fetchCols exec.ColumnOrdinalSet, updateCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { return struct{}{}, nil } @@ -247,14 +247,17 @@ func (f *stubFactory) ConstructUpsert( insertCols exec.ColumnOrdinalSet, fetchCols exec.ColumnOrdinalSet, updateCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { return struct{}{}, nil } func (f *stubFactory) ConstructDelete( - input exec.Node, table cat.Table, fetchCols exec.ColumnOrdinalSet, rowsNeeded bool, + input exec.Node, + table cat.Table, + fetchCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, ) (exec.Node, error) { return struct{}{}, nil } diff --git a/pkg/sql/opt/exec/execbuilder/mutation.go b/pkg/sql/opt/exec/execbuilder/mutation.go index 1664bbdd828a..0d52d6898fab 100644 --- a/pkg/sql/opt/exec/execbuilder/mutation.go +++ b/pkg/sql/opt/exec/execbuilder/mutation.go @@ -47,14 +47,15 @@ func (b *Builder) buildInsert(ins *memo.InsertExpr) (execPlan, error) { tab := b.mem.Metadata().Table(ins.Table) insertOrds := ordinalSetFromColList(ins.InsertCols) checkOrds := ordinalSetFromColList(ins.CheckCols) + returnOrds := ordinalSetFromColList(ins.ReturnCols) // If we planned FK checks, disable the execution code for FK checks. 
disableExecFKs := len(ins.Checks) > 0 node, err := b.factory.ConstructInsert( input.root, tab, insertOrds, + returnOrds, checkOrds, - ins.NeedResults(), disableExecFKs, ) if err != nil { @@ -106,14 +107,15 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { tab := md.Table(upd.Table) fetchColOrds := ordinalSetFromColList(upd.FetchCols) updateColOrds := ordinalSetFromColList(upd.UpdateCols) + returnColOrds := ordinalSetFromColList(upd.ReturnCols) checkOrds := ordinalSetFromColList(upd.CheckCols) node, err := b.factory.ConstructUpdate( input.root, tab, fetchColOrds, updateColOrds, + returnColOrds, checkOrds, - upd.NeedResults(), ) if err != nil { return execPlan{}, err @@ -177,6 +179,7 @@ func (b *Builder) buildUpsert(ups *memo.UpsertExpr) (execPlan, error) { insertColOrds := ordinalSetFromColList(ups.InsertCols) fetchColOrds := ordinalSetFromColList(ups.FetchCols) updateColOrds := ordinalSetFromColList(ups.UpdateCols) + returnColOrds := ordinalSetFromColList(ups.ReturnCols) checkOrds := ordinalSetFromColList(ups.CheckCols) node, err := b.factory.ConstructUpsert( input.root, @@ -185,8 +188,8 @@ func (b *Builder) buildUpsert(ups *memo.UpsertExpr) (execPlan, error) { insertColOrds, fetchColOrds, updateColOrds, + returnColOrds, checkOrds, - ups.NeedResults(), ) if err != nil { return execPlan{}, err @@ -230,7 +233,8 @@ func (b *Builder) buildDelete(del *memo.DeleteExpr) (execPlan, error) { md := b.mem.Metadata() tab := md.Table(del.Table) fetchColOrds := ordinalSetFromColList(del.FetchCols) - node, err := b.factory.ConstructDelete(input.root, tab, fetchColOrds, del.NeedResults()) + returnColOrds := ordinalSetFromColList(del.ReturnCols) + node, err := b.factory.ConstructDelete(input.root, tab, fetchColOrds, returnColOrds) if err != nil { return execPlan{}, err } @@ -310,6 +314,9 @@ func appendColsWhenPresent(dst, src opt.ColList) opt.ColList { // indicating columns that are not involved in the mutation. 
func ordinalSetFromColList(colList opt.ColList) exec.ColumnOrdinalSet { var res util.FastIntSet + if colList == nil { + return res + } for i, col := range colList { if col != 0 { res.Add(i) diff --git a/pkg/sql/opt/exec/execbuilder/testdata/ddl b/pkg/sql/opt/exec/execbuilder/testdata/ddl index 6f6bbfd91f7e..5b75e361dbe6 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/ddl +++ b/pkg/sql/opt/exec/execbuilder/testdata/ddl @@ -233,19 +233,17 @@ COMMIT query TTTTT colnames EXPLAIN (VERBOSE) SELECT * FROM v ---- -tree field description columns ordering -render · · (k) · - │ render 0 k · · - └── run · · (k, v, z) · - └── update · · (k, v, z) · - │ table kv · · - │ set v · · - │ strategy updater · · - └── render · · (k, v, z, column7) · - │ render 0 k · · - │ render 1 v · · - │ render 2 z · · - │ render 3 444 · · - └── scan · · (k, v, z) · -· table kv@primary · · -· spans /1- · · +tree field description columns ordering +run · · (k) · + └── update · · (k) · + │ table kv · · + │ set v · · + │ strategy updater · · + └── render · · (k, v, z, column7) · + │ render 0 k · · + │ render 1 v · · + │ render 2 z · · + │ render 3 444 · · + └── scan · · (k, v, z) · +· table kv@primary · · +· spans /1- · · diff --git a/pkg/sql/opt/exec/execbuilder/testdata/delete b/pkg/sql/opt/exec/execbuilder/testdata/delete index a496b8a4df5e..0df0b0564a73 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/delete +++ b/pkg/sql/opt/exec/execbuilder/testdata/delete @@ -146,14 +146,11 @@ count · · query TTT EXPLAIN DELETE FROM indexed WHERE value = 5 LIMIT 10 RETURNING id ---- -render · · - └── run · · - └── delete · · - │ from indexed - │ strategy deleter - └── index-join · · - │ table indexed@primary - └── scan · · -· table indexed@indexed_value_idx -· spans /5-/6 -· limit 10 +run · · + └── delete · · + │ from indexed + │ strategy deleter + └── scan · · +· table indexed@indexed_value_idx +· spans /5-/6 +· limit 10 diff --git a/pkg/sql/opt/exec/execbuilder/testdata/insert b/pkg/sql/opt/exec/execbuilder/testdata/insert index 16539a03e2b1..e802c9a4bd79 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/insert +++ b/pkg/sql/opt/exec/execbuilder/testdata/insert @@ -480,20 +480,20 @@ CREATE TABLE xyz (x INT, y INT, z INT) query TTTTT EXPLAIN (VERBOSE) SELECT * FROM [INSERT INTO xyz SELECT a, b, c FROM abc RETURNING z] ORDER BY z ---- -render · · (z) +z - │ render 0 z · · - └── run · · (x, y, z, rowid[hidden]) · - └── insert · · (x, y, z, rowid[hidden]) · - │ into xyz(x, y, z, rowid) · · - │ strategy inserter · · - └── render · · (a, b, c, column9) +c - │ render 0 a · · - │ render 1 b · · - │ render 2 c · · - │ render 3 unique_rowid() · · - └── scan · · (a, b, c) +c -· table abc@abc_c_idx · · -· spans ALL · · +render · · (z) +z + │ render 0 z · · + └── run · · (z, rowid[hidden]) · + └── insert · · (z, rowid[hidden]) · + │ into xyz(x, y, z, rowid) · · + │ strategy inserter · · + └── render · · (a, b, c, column9) +c + │ render 0 a · · + │ render 1 b · · + │ render 2 c · · + │ render 3 unique_rowid() · · + └── scan · · (a, b, c) +c +· table abc@abc_c_idx · · +· spans ALL · · # ------------------------------------------------------------------------------ # Regression for #35364. 
This tests behavior that is different between the CBO diff --git a/pkg/sql/opt/exec/execbuilder/testdata/orderby b/pkg/sql/opt/exec/execbuilder/testdata/orderby index 1a370145d790..aac322bf3908 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/orderby +++ b/pkg/sql/opt/exec/execbuilder/testdata/orderby @@ -483,8 +483,8 @@ EXPLAIN (VERBOSE) INSERT INTO t(a, b) SELECT * FROM (SELECT 1 AS x, 2 AS y) ORDE ---- render · · (b) · │ render 0 b · · - └── run · · (a, b, c) · - └── insert · · (a, b, c) · + └── run · · (a, b) · + └── insert · · (a, b) · │ into t(a, b, c) · · │ strategy inserter · · └── values · · (x, y, column6) · @@ -496,23 +496,23 @@ render · · (b) · query TTTTT EXPLAIN (VERBOSE) DELETE FROM t WHERE a = 3 RETURNING b ---- -render · · (b) · - │ render 0 b · · - └── run · · (a, b, c) · - └── delete · · (a, b, c) · - │ from t · · - │ strategy deleter · · - └── scan · · (a, b, c) · -· table t@primary · · -· spans /3-/3/# · · +render · · (b) · + │ render 0 b · · + └── run · · (a, b) · + └── delete · · (a, b) · + │ from t · · + │ strategy deleter · · + └── scan · · (a, b) · +· table t@primary · · +· spans /3-/3/# · · query TTTTT EXPLAIN (VERBOSE) UPDATE t SET c = TRUE RETURNING b ---- render · · (b) · │ render 0 b · · - └── run · · (a, b, c) · - └── update · · (a, b, c) · + └── run · · (a, b) · + └── update · · (a, b) · │ table t · · │ set c · · │ strategy updater · · diff --git a/pkg/sql/opt/exec/execbuilder/testdata/update b/pkg/sql/opt/exec/execbuilder/testdata/update index a55117f4709a..d11f3696a773 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/update +++ b/pkg/sql/opt/exec/execbuilder/testdata/update @@ -308,8 +308,8 @@ EXPLAIN (VERBOSE) SELECT * FROM [ UPDATE abc SET a=c RETURNING a ] ORDER BY a ---- render · · (a) +a │ render 0 a · · - └── run · · (a, b, c, rowid[hidden]) · - └── update · · (a, b, c, rowid[hidden]) · + └── run · · (a, rowid[hidden]) · + └── update · · (a, rowid[hidden]) · │ table abc · · │ set a · · │ strategy updater · · diff --git a/pkg/sql/opt/exec/execbuilder/testdata/upsert b/pkg/sql/opt/exec/execbuilder/testdata/upsert index 41c44a9e95f6..ebe20303b4fe 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/upsert +++ b/pkg/sql/opt/exec/execbuilder/testdata/upsert @@ -327,8 +327,8 @@ EXPLAIN (VERBOSE) SELECT * FROM [UPSERT INTO xyz SELECT a, b, c FROM abc RETURNI ---- render · · (z) +z │ render 0 z · · - └── run · · (x, y, z, rowid[hidden]) · - └── upsert · · (x, y, z, rowid[hidden]) · + └── run · · (z, rowid[hidden]) · + └── upsert · · (z, rowid[hidden]) · │ into xyz(x, y, z, rowid) · · │ strategy opt upserter · · └── render · · (a, b, c, column9, a, b, c) +c diff --git a/pkg/sql/opt/exec/factory.go b/pkg/sql/opt/exec/factory.go index f8344ad064be..be7e6a8a596f 100644 --- a/pkg/sql/opt/exec/factory.go +++ b/pkg/sql/opt/exec/factory.go @@ -304,8 +304,8 @@ type Factory interface { input Node, table cat.Table, insertCols ColumnOrdinalSet, + returnCols ColumnOrdinalSet, checks CheckOrdinalSet, - rowsNeeded bool, skipFKChecks bool, ) (Node, error) @@ -326,8 +326,8 @@ type Factory interface { table cat.Table, fetchCols ColumnOrdinalSet, updateCols ColumnOrdinalSet, + returnCols ColumnOrdinalSet, checks CheckOrdinalSet, - rowsNeeded bool, ) (Node, error) // ConstructUpsert creates a node that implements an INSERT..ON CONFLICT or @@ -360,8 +360,8 @@ type Factory interface { insertCols ColumnOrdinalSet, fetchCols ColumnOrdinalSet, updateCols ColumnOrdinalSet, + returnCols ColumnOrdinalSet, checks CheckOrdinalSet, - rowsNeeded bool, ) (Node, error) // ConstructDelete 
creates a node that implements a DELETE statement. The @@ -373,7 +373,7 @@ type Factory interface { // as they appear in the table schema. The rowsNeeded parameter is true if a // RETURNING clause needs the deleted row(s) as output. ConstructDelete( - input Node, table cat.Table, fetchCols ColumnOrdinalSet, rowsNeeded bool, + input Node, table cat.Table, fetchCols ColumnOrdinalSet, returnCols ColumnOrdinalSet, ) (Node, error) // ConstructDeleteRange creates a node that efficiently deletes contiguous diff --git a/pkg/sql/opt/memo/expr.go b/pkg/sql/opt/memo/expr.go index 9e1503afcf26..4c1bdc79986c 100644 --- a/pkg/sql/opt/memo/expr.go +++ b/pkg/sql/opt/memo/expr.go @@ -381,6 +381,12 @@ func (m *MutationPrivate) NeedResults() bool { return m.ReturnCols != nil } +// IsColumnOutput returns true if the i-th ordinal column should be part of the +// mutation's output columns. +func (m *MutationPrivate) IsColumnOutput(i int) bool { + return i < len(m.ReturnCols) && m.ReturnCols[i] != 0 +} + // MapToInputID maps from the ID of a returned column to the ID of the // corresponding input column that provides the value for it. If there is no // matching input column ID, MapToInputID returns 0. diff --git a/pkg/sql/opt/memo/logical_props_builder.go b/pkg/sql/opt/memo/logical_props_builder.go index 19326edff425..4b4d99fc52e6 100644 --- a/pkg/sql/opt/memo/logical_props_builder.go +++ b/pkg/sql/opt/memo/logical_props_builder.go @@ -1088,8 +1088,10 @@ func (b *logicalPropsBuilder) buildMutationProps(mutation RelExpr, rel *props.Re // -------------- // Only non-mutation columns are output columns. for i, n := 0, tab.ColumnCount(); i < n; i++ { - colID := private.Table.ColumnID(i) - rel.OutputCols.Add(colID) + if private.IsColumnOutput(i) { + colID := private.Table.ColumnID(i) + rel.OutputCols.Add(colID) + } } // Not Null Columns diff --git a/pkg/sql/opt/norm/custom_funcs.go b/pkg/sql/opt/norm/custom_funcs.go index 9e21703db674..06a8ce9098df 100644 --- a/pkg/sql/opt/norm/custom_funcs.go +++ b/pkg/sql/opt/norm/custom_funcs.go @@ -350,6 +350,27 @@ func (c *CustomFuncs) sharedProps(e opt.Expr) *props.Shared { panic(errors.AssertionFailedf("no logical properties available for node: %v", e)) } +// MutationTable returns the table upon which the mutation is applied. +func (c *CustomFuncs) MutationTable(private *memo.MutationPrivate) opt.TableID { + return private.Table +} + +// PrimaryKeyCols returns the key columns of the primary key of the table. +func (c *CustomFuncs) PrimaryKeyCols(table opt.TableID) opt.ColSet { + var primaryKeyCols opt.ColSet + tabID := c.mem.Metadata().TableMeta(table).MetaID + + // The columns of the primary key are always returned regardless of + // whether they are referenced. + tab := c.mem.Metadata().Table(table) + primaryIndex := tab.Index(0) + for i, n := 0, primaryIndex.KeyColumnCount(); i < n; i++ { + primaryKeyCols.Add(tabID.ColumnID(primaryIndex.Column(i).Ordinal)) + } + + return primaryKeyCols +} + // ---------------------------------------------------------------------- // // Ordering functions diff --git a/pkg/sql/opt/norm/prune_cols.go b/pkg/sql/opt/norm/prune_cols.go index 0f0818e01bf3..671a2639784d 100644 --- a/pkg/sql/opt/norm/prune_cols.go +++ b/pkg/sql/opt/norm/prune_cols.go @@ -509,3 +509,43 @@ func DerivePruneCols(e memo.RelExpr) opt.ColSet { return relProps.Rule.PruneCols } + +// CanPruneMutationReturnCols checks whether the mutation's return columns can +// be pruned. This is the pre-condition for the PruneMutationReturnCols rule. 
+func (c *CustomFuncs) CanPruneMutationReturnCols( + private *memo.MutationPrivate, needed opt.ColSet, +) bool { + if private.ReturnCols == nil { + return false + } + + tabID := c.mem.Metadata().TableMeta(private.Table).MetaID + for i := range private.ReturnCols { + if private.ReturnCols[i] != 0 && !needed.Contains(tabID.ColumnID(i)) { + return true + } + } + + return false +} + +// PruneMutationReturnCols rewrites the given mutation private to no longer +// keep ReturnCols that are not referenced by the RETURNING clause or are not +// part of the primary key. The caller must have already done the analysis to +// prove that such columns exist, by calling CanPruneMutationReturnCols. +func (c *CustomFuncs) PruneMutationReturnCols( + private *memo.MutationPrivate, needed opt.ColSet, +) *memo.MutationPrivate { + newPrivate := *private + newReturnCols := make(opt.ColList, len(private.ReturnCols)) + tabID := c.mem.Metadata().TableMeta(private.Table).MetaID + + for i := range private.ReturnCols { + if needed.Contains(tabID.ColumnID(i)) { + newReturnCols[i] = private.ReturnCols[i] + } + } + + newPrivate.ReturnCols = newReturnCols + return &newPrivate +} diff --git a/pkg/sql/opt/norm/rules/prune_cols.opt b/pkg/sql/opt/norm/rules/prune_cols.opt index 9e3bd969f087..8dcc5e6c590e 100644 --- a/pkg/sql/opt/norm/rules/prune_cols.opt +++ b/pkg/sql/opt/norm/rules/prune_cols.opt @@ -462,3 +462,37 @@ $checks $mutationPrivate ) + +# PruneReturningCols removes columns from the mutation operator's ReturnCols +# set if they are not used in the RETURNING clause of the mutation. +# Removing ReturnCols will then allow the PruneMutationFetchCols to be more +# conservative with the fetch columns. +[PruneMutationReturnCols, Normalize] +(Project + $input:(Insert | Update | Upsert | Delete + $innerInput:* + $checks:* + $mutationPrivate:* + ) + $projections:* + $passthrough:* & + (CanPruneMutationReturnCols + $mutationPrivate + $needed:(UnionCols3 + (PrimaryKeyCols (MutationTable $mutationPrivate)) + (ProjectionOuterCols $projections) + $passthrough + ) + ) +) +=> +(Project + ((OpName $input) + $innerInput + $checks + (PruneMutationReturnCols $mutationPrivate $needed) + ) + $projections + $passthrough +) + diff --git a/pkg/sql/opt/norm/testdata/rules/prune_cols b/pkg/sql/opt/norm/testdata/rules/prune_cols index 9c585e0c0052..2ed2154dd22c 100644 --- a/pkg/sql/opt/norm/testdata/rules/prune_cols +++ b/pkg/sql/opt/norm/testdata/rules/prune_cols @@ -1883,26 +1883,19 @@ delete mutation ├── key: (6) └── fd: (6)-->(7,9,10) -# No pruning when RETURNING clause is present. -# TODO(andyk): Need to prune output columns. -opt expect-not=(PruneMutationFetchCols,PruneMutationInputCols) +opt expect=(PruneMutationFetchCols,PruneMutationInputCols) DELETE FROM a RETURNING k, s ---- -project +delete a ├── columns: k:1(int!null) s:4(string) + ├── fetch columns: k:5(int) s:8(string) ├── side-effects, mutations ├── key: (1) ├── fd: (1)-->(4) - └── delete a - ├── columns: k:1(int!null) i:2(int) f:3(float) s:4(string) - ├── fetch columns: k:5(int) i:6(int) f:7(float) s:8(string) - ├── side-effects, mutations - ├── key: (1) - ├── fd: (1)-->(2-4) - └── scan a - ├── columns: k:5(int!null) i:6(int) f:7(float) s:8(string) - ├── key: (5) - └── fd: (5)-->(6-8) + └── scan a + ├── columns: k:5(int!null) s:8(string) + ├── key: (5) + └── fd: (5)-->(8) # Prune secondary family column not needed for the update. 
opt expect=(PruneMutationFetchCols,PruneMutationInputCols) @@ -1945,29 +1938,28 @@ update "family" └── a + 1 [type=int, outer=(6)] # Do not prune columns that must be returned. -# TODO(justin): in order to prune e here we need a PruneMutationReturnCols rule. -opt expect-not=PruneMutationFetchCols +opt expect=(PruneMutationFetchCols, PruneMutationReturnCols) UPDATE family SET c=c+1 RETURNING b ---- project ├── columns: b:2(int) ├── side-effects, mutations └── update "family" - ├── columns: a:1(int!null) b:2(int) c:3(int) d:4(int) e:5(int) - ├── fetch columns: a:6(int) b:7(int) c:8(int) d:9(int) e:10(int) + ├── columns: a:1(int!null) b:2(int) + ├── fetch columns: a:6(int) b:7(int) c:8(int) d:9(int) ├── update-mapping: │ └── column11:11 => c:3 ├── side-effects, mutations ├── key: (1) - ├── fd: (1)-->(2-5) + ├── fd: (1)-->(2) └── project - ├── columns: column11:11(int) a:6(int!null) b:7(int) c:8(int) d:9(int) e:10(int) + ├── columns: column11:11(int) a:6(int!null) b:7(int) c:8(int) d:9(int) ├── key: (6) - ├── fd: (6)-->(7-10), (8)-->(11) + ├── fd: (6)-->(7-9), (8)-->(11) ├── scan "family" - │ ├── columns: a:6(int!null) b:7(int) c:8(int) d:9(int) e:10(int) + │ ├── columns: a:6(int!null) b:7(int) c:8(int) d:9(int) │ ├── key: (6) - │ └── fd: (6)-->(7-10) + │ └── fd: (6)-->(7-9) └── projections └── c + 1 [type=int, outer=(8)] @@ -2115,9 +2107,9 @@ project ├── key: () ├── fd: ()-->(5) └── upsert "family" - ├── columns: a:1(int!null) b:2(int) c:3(int) d:4(int) e:5(int) + ├── columns: a:1(int!null) e:5(int) ├── canary column: 11 - ├── fetch columns: a:11(int) b:12(int) c:13(int) d:14(int) e:15(int) + ├── fetch columns: a:11(int) c:13(int) d:14(int) e:15(int) ├── insert-mapping: │ ├── column1:6 => a:1 │ ├── column2:7 => b:2 @@ -2128,24 +2120,21 @@ project │ └── upsert_c:19 => c:3 ├── return-mapping: │ ├── upsert_a:17 => a:1 - │ ├── upsert_b:18 => b:2 - │ ├── upsert_c:19 => c:3 - │ ├── upsert_d:20 => d:4 │ └── upsert_e:21 => e:5 ├── cardinality: [1 - 1] ├── side-effects, mutations ├── key: () - ├── fd: ()-->(1-5) + ├── fd: ()-->(1,5) └── project - ├── columns: upsert_a:17(int) upsert_b:18(int) upsert_c:19(int) upsert_d:20(int) upsert_e:21(int) column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) b:12(int) c:13(int) d:14(int) e:15(int) + ├── columns: upsert_a:17(int) upsert_c:19(int) upsert_e:21(int) column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) c:13(int) d:14(int) e:15(int) ├── cardinality: [1 - 1] ├── key: () - ├── fd: ()-->(6-15,17-21) + ├── fd: ()-->(6-11,13-15,17,19,21) ├── left-join (hash) - │ ├── columns: column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) b:12(int) c:13(int) d:14(int) e:15(int) + │ ├── columns: column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) c:13(int) d:14(int) e:15(int) │ ├── cardinality: [1 - 1] │ ├── key: () - │ ├── fd: ()-->(6-15) + │ ├── fd: ()-->(6-11,13-15) │ ├── values │ │ ├── columns: column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) │ │ ├── cardinality: [1 - 1] @@ -2153,20 +2142,17 @@ project │ │ ├── fd: ()-->(6-10) │ │ └── (1, 2, 3, 4, 5) [type=tuple{int, int, int, int, int}] │ ├── scan "family" - │ │ ├── columns: a:11(int!null) b:12(int) c:13(int) d:14(int) e:15(int) + │ │ ├── columns: a:11(int!null) c:13(int) d:14(int) e:15(int) │ │ ├── constraint: /11: [/1 - /1] │ │ ├── 
cardinality: [0 - 1] │ │ ├── key: () - │ │ └── fd: ()-->(11-15) + │ │ └── fd: ()-->(11,13-15) │ └── filters (true) └── projections ├── CASE WHEN a IS NULL THEN column1 ELSE a END [type=int, outer=(6,11)] - ├── CASE WHEN a IS NULL THEN column2 ELSE b END [type=int, outer=(7,11,12)] ├── CASE WHEN a IS NULL THEN column3 ELSE 10 END [type=int, outer=(8,11)] - ├── CASE WHEN a IS NULL THEN column4 ELSE d END [type=int, outer=(9,11,14)] └── CASE WHEN a IS NULL THEN column5 ELSE e END [type=int, outer=(10,11,15)] - # Do not prune column in same secondary family as updated column. But prune # non-key column in primary family. opt expect=(PruneMutationFetchCols,PruneMutationInputCols) @@ -2254,3 +2240,319 @@ upsert mutation │ └── filters (true) └── projections └── CASE WHEN a IS NULL THEN column2 ELSE 10 END [type=int, outer=(7,10)] + +# ------------------------------------------------------------------------------ +# PruneMutationReturnCols +# ------------------------------------------------------------------------------ + +# Create a table with multiple column families the mutations can take advantage of. +exec-ddl +CREATE TABLE returning_test ( + a INT, + b INT, + c STRING, + d INT, + e INT, + f INT, + g INT, + FAMILY (a), + FAMILY (b), + FAMILY (c), + FAMILY (d, e, f, g), + UNIQUE (a) +) +---- + +# Fetch all the columns for the RETURN expression. +opt +UPDATE returning_test SET a = a + 1 RETURNING * +---- +project + ├── columns: a:1(int) b:2(int) c:3(string) d:4(int) e:5(int) f:6(int) g:7(int) + ├── side-effects, mutations + └── update returning_test + ├── columns: a:1(int) b:2(int) c:3(string) d:4(int) e:5(int) f:6(int) g:7(int) rowid:8(int!null) + ├── fetch columns: a:9(int) b:10(int) c:11(string) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => a:1 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1-7) + └── project + ├── columns: column17:17(int) a:9(int) b:10(int) c:11(string) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9-15), (9)~~>(10-16), (9)-->(17) + ├── scan returning_test + │ ├── columns: a:9(int) b:10(int) c:11(string) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9-15), (9)~~>(10-16) + └── projections + └── a + 1 [type=int, outer=(9)] + + +# Fetch all the columns in the (d, e, f, g) family as d is being set. +opt +UPDATE returning_test SET d = a + d RETURNING a, d +---- +project + ├── columns: a:1(int) d:4(int) + ├── side-effects, mutations + ├── lax-key: (1,4) + ├── fd: (1)~~>(4) + └── update returning_test + ├── columns: a:1(int) d:4(int) rowid:8(int!null) + ├── fetch columns: a:9(int) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => d:4 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1,4), (1)~~>(4,8) + └── project + ├── columns: column17:17(int) a:9(int) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9,12-15), (9)~~>(12-16), (9,12)-->(17) + ├── scan returning_test + │ ├── columns: a:9(int) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9,12-15), (9)~~>(12-16) + └── projections + └── a + d [type=int, outer=(9,12)] + +# Fetch only whats being updated (not the (d, e, f, g) family). 
+opt +UPDATE returning_test SET a = a + d RETURNING a +---- +project + ├── columns: a:1(int) + ├── side-effects, mutations + └── update returning_test + ├── columns: a:1(int) rowid:8(int!null) + ├── fetch columns: a:9(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => a:1 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1) + └── project + ├── columns: column17:17(int) a:9(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9,17), (9)~~>(16,17) + ├── scan returning_test + │ ├── columns: a:9(int) d:12(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9,12), (9)~~>(12,16) + └── projections + └── a + d [type=int, outer=(9,12)] + +# We only fetch the minimal set of columns which is (a, b, c, rowid). +opt +UPDATE returning_test SET (b, a) = (a, a + b) RETURNING a, b, c +---- +project + ├── columns: a:1(int) b:2(int) c:3(string) + ├── side-effects, mutations + ├── lax-key: (1-3) + ├── fd: (2)~~>(1,3) + └── update returning_test + ├── columns: a:1(int) b:2(int) c:3(string) rowid:8(int!null) + ├── fetch columns: a:9(int) b:10(int) c:11(string) rowid:16(int) + ├── update-mapping: + │ ├── column17:17 => a:1 + │ └── a:9 => b:2 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1-3), (2)~~>(1,3,8) + └── project + ├── columns: column17:17(int) a:9(int) b:10(int) c:11(string) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9-11), (9)~~>(10,11,16), (9,10)-->(17) + ├── scan returning_test + │ ├── columns: a:9(int) b:10(int) c:11(string) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9-11), (9)~~>(10,11,16) + └── projections + └── a + b [type=int, outer=(9,10)] + + +# We apply the PruneMutationReturnCols rule multiple times, to get +# the minimal set of columns which is (a, rowid). Notice how c and b +# are pruned away. +opt +SELECT a FROM [SELECT a, b FROM [UPDATE returning_test SET a = a + 1 RETURNING a, b, c]] +---- +project + ├── columns: a:1(int) + ├── side-effects, mutations + └── update returning_test + ├── columns: a:1(int) rowid:8(int!null) + ├── fetch columns: a:9(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => a:1 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1) + └── project + ├── columns: column17:17(int) a:9(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9), (9)~~>(16), (9)-->(17) + ├── scan returning_test@secondary + │ ├── columns: a:9(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9), (9)~~>(16) + └── projections + └── a + 1 [type=int, outer=(9)] + +# Check if the rule works as desired for other mutations. 
+opt +INSERT INTO returning_test VALUES (1, 2, 'c') ON CONFLICT (a) DO UPDATE SET a = excluded.a + returning_test.a RETURNING a, b, c +---- +project + ├── columns: a:1(int) b:2(int) c:3(string) + ├── cardinality: [1 - 1] + ├── side-effects, mutations + ├── key: () + ├── fd: ()-->(1-3) + └── upsert returning_test + ├── columns: a:1(int) b:2(int) c:3(string) rowid:8(int!null) + ├── canary column: 21 + ├── fetch columns: a:14(int) b:15(int) c:16(string) rowid:21(int) + ├── insert-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ ├── column3:11 => c:3 + │ ├── column12:12 => d:4 + │ ├── column12:12 => e:5 + │ ├── column12:12 => f:6 + │ ├── column12:12 => g:7 + │ └── column13:13 => rowid:8 + ├── update-mapping: + │ └── upsert_a:23 => a:1 + ├── return-mapping: + │ ├── upsert_a:23 => a:1 + │ ├── upsert_b:24 => b:2 + │ ├── upsert_c:25 => c:3 + │ └── upsert_rowid:30 => rowid:8 + ├── cardinality: [1 - 1] + ├── side-effects, mutations + ├── key: () + ├── fd: ()-->(1-3,8) + └── project + ├── columns: upsert_a:23(int) upsert_b:24(int) upsert_c:25(string) upsert_rowid:30(int) column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) rowid:21(int) + ├── cardinality: [1 - 1] + ├── side-effects + ├── key: () + ├── fd: ()-->(9-16,21,23-25,30) + ├── left-join (hash) + │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) rowid:21(int) + │ ├── cardinality: [1 - 1] + │ ├── side-effects + │ ├── key: () + │ ├── fd: ()-->(9-16,21) + │ ├── values + │ │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) + │ │ ├── cardinality: [1 - 1] + │ │ ├── side-effects + │ │ ├── key: () + │ │ ├── fd: ()-->(9-13) + │ │ └── (1, 2, 'c', CAST(NULL AS INT8), unique_rowid()) [type=tuple{int, int, string, int, int}] + │ ├── index-join returning_test + │ │ ├── columns: a:14(int!null) b:15(int) c:16(string) rowid:21(int!null) + │ │ ├── cardinality: [0 - 1] + │ │ ├── key: () + │ │ ├── fd: ()-->(14-16,21) + │ │ └── scan returning_test@secondary + │ │ ├── columns: a:14(int!null) rowid:21(int!null) + │ │ ├── constraint: /14: [/1 - /1] + │ │ ├── cardinality: [0 - 1] + │ │ ├── key: () + │ │ └── fd: ()-->(14,21) + │ └── filters (true) + └── projections + ├── CASE WHEN rowid IS NULL THEN column1 ELSE column1 + a END [type=int, outer=(9,14,21)] + ├── CASE WHEN rowid IS NULL THEN column2 ELSE b END [type=int, outer=(10,15,21)] + ├── CASE WHEN rowid IS NULL THEN column3 ELSE c END [type=string, outer=(11,16,21)] + └── CASE WHEN rowid IS NULL THEN column13 ELSE rowid END [type=int, outer=(13,21)] + +opt +DELETE FROM returning_test WHERE a < b + d RETURNING a, b, d +---- +project + ├── columns: a:1(int!null) b:2(int) d:4(int) + ├── side-effects, mutations + ├── key: (1) + ├── fd: (1)-->(2,4) + └── delete returning_test + ├── columns: a:1(int!null) b:2(int) d:4(int) rowid:8(int!null) + ├── fetch columns: a:9(int) b:10(int) d:12(int) rowid:16(int) + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1,2,4), (1)-->(2,4,8) + └── select + ├── columns: a:9(int!null) b:10(int) d:12(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9,10,12), (9)-->(10,12,16) + ├── scan returning_test + │ ├── columns: a:9(int) b:10(int) d:12(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9,10,12), (9)~~>(10,12,16) + └── filters + └── a < (b + d) [type=bool, outer=(9,10,12), constraints=(/9: (/NULL - ])] + 
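+# The UPSERT below keeps only the requested columns plus the primary key
+# column (rowid), which PruneMutationReturnCols never prunes; e, f and g are
+# pruned from the return mapping.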
+opt +UPSERT INTO returning_test (a, b, c) VALUES (1, 2, 'c') RETURNING a, b, c, d +---- +project + ├── columns: a:1(int!null) b:2(int!null) c:3(string!null) d:4(int) + ├── cardinality: [1 - ] + ├── side-effects, mutations + ├── fd: ()-->(1-3) + └── upsert returning_test + ├── columns: a:1(int!null) b:2(int!null) c:3(string!null) d:4(int) rowid:8(int!null) + ├── canary column: 21 + ├── fetch columns: a:14(int) b:15(int) c:16(string) d:17(int) rowid:21(int) + ├── insert-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ ├── column3:11 => c:3 + │ ├── column12:12 => d:4 + │ ├── column12:12 => e:5 + │ ├── column12:12 => f:6 + │ ├── column12:12 => g:7 + │ └── column13:13 => rowid:8 + ├── update-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ └── column3:11 => c:3 + ├── return-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ ├── column3:11 => c:3 + │ ├── upsert_d:22 => d:4 + │ └── upsert_rowid:26 => rowid:8 + ├── cardinality: [1 - ] + ├── side-effects, mutations + ├── fd: ()-->(1-3) + └── project + ├── columns: upsert_d:22(int) upsert_rowid:26(int) column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) d:17(int) rowid:21(int) + ├── cardinality: [1 - ] + ├── side-effects + ├── key: (21) + ├── fd: ()-->(9-13), (21)-->(14-17), (14)~~>(15-17,21), (17,21)-->(22), (21)-->(26) + ├── left-join (lookup returning_test) + │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) d:17(int) rowid:21(int) + │ ├── key columns: [13] = [21] + │ ├── cardinality: [1 - ] + │ ├── side-effects + │ ├── key: (21) + │ ├── fd: ()-->(9-13), (21)-->(14-17), (14)~~>(15-17,21) + │ ├── values + │ │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) + │ │ ├── cardinality: [1 - 1] + │ │ ├── side-effects + │ │ ├── key: () + │ │ ├── fd: ()-->(9-13) + │ │ └── (1, 2, 'c', CAST(NULL AS INT8), unique_rowid()) [type=tuple{int, int, string, int, int}] + │ └── filters (true) + └── projections + ├── CASE WHEN rowid IS NULL THEN column12 ELSE d END [type=int, outer=(12,17,21)] + └── CASE WHEN rowid IS NULL THEN column13 ELSE rowid END [type=int, outer=(13,21)] diff --git a/pkg/sql/opt/xform/testdata/rules/join b/pkg/sql/opt/xform/testdata/rules/join index 57a27fa7e5a5..72bf8c44af5a 100644 --- a/pkg/sql/opt/xform/testdata/rules/join +++ b/pkg/sql/opt/xform/testdata/rules/join @@ -2116,24 +2116,21 @@ project ├── side-effects, mutations ├── fd: ()-->(21) ├── inner-join (hash) - │ ├── columns: abc.a:5(int!null) abc.b:6(int) abc.c:7(int) abc.rowid:8(int!null) + │ ├── columns: abc.rowid:8(int!null) │ ├── cardinality: [0 - 0] │ ├── side-effects, mutations - │ ├── fd: ()-->(5-7) │ ├── select - │ │ ├── columns: abc.a:5(int!null) abc.b:6(int) abc.c:7(int) abc.rowid:8(int!null) + │ │ ├── columns: abc.rowid:8(int!null) │ │ ├── cardinality: [0 - 0] │ │ ├── side-effects, mutations - │ │ ├── fd: ()-->(5-7) │ │ ├── insert abc - │ │ │ ├── columns: abc.a:5(int!null) abc.b:6(int) abc.c:7(int) abc.rowid:8(int!null) + │ │ │ ├── columns: abc.rowid:8(int!null) │ │ │ ├── insert-mapping: │ │ │ │ ├── "?column?":13 => abc.a:5 │ │ │ │ ├── column14:14 => abc.b:6 │ │ │ │ ├── column14:14 => abc.c:7 │ │ │ │ └── column15:15 => abc.rowid:8 │ │ │ ├── side-effects, mutations - │ │ │ ├── fd: ()-->(5-7) │ │ │ └── project │ │ │ ├── columns: column14:14(int) column15:15(int) "?column?":13(int!null) │ │ │ ├── 
side-effects diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index ff16ec9364b1..d91be92b3f1a 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -1153,15 +1153,41 @@ func (ef *execFactory) ConstructShowTrace(typ tree.ShowTraceType, compact bool) return node, nil } +// mutationRowIdxToReturnIdx returns the mapping from the origColDescs to the +// returningColDescs. -1 is used for columns not part of the returnColDescs. +// It is the responsibility of the caller to ensure a a mapping is possible, +// that is, the returningColDescs is a subset of the origColDescs. +func mutationRowIdxToReturnIdx(origColDescs, returnColDescs []sqlbase.ColumnDescriptor) []int { + // Create a ColumnID to index map. + colIDToRetIndex := row.ColIDtoRowIndexFromCols(origColDescs) + + // Initialize the rowIdxToTabColIdx array. + rowIdxToRetIdx := make([]int, len(origColDescs)) + for i := range origColDescs { + // -1 value indicates that this column is not being returned. + rowIdxToRetIdx[i] = -1 + } + + // Set the appropriate index values for the returning columns. + for i := range returnColDescs { + if idx, ok := colIDToRetIndex[returnColDescs[i].ID]; ok { + rowIdxToRetIdx[idx] = i + } + } + + return rowIdxToRetIdx +} + func (ef *execFactory) ConstructInsert( input exec.Node, table cat.Table, insertCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, skipFKChecks bool, ) (exec.Node, error) { // Derive insert table and column descriptors. + rowsNeeded := !returnCols.Empty() tabDesc := table.(*optTable).desc colDescs := makeColDescList(table, insertCols) @@ -1187,18 +1213,33 @@ func (ef *execFactory) ConstructInsert( // Determine the relational type of the generated insert node. // If rows are not needed, no columns are returned. - var returnCols sqlbase.ResultColumns + var returnColumns sqlbase.ResultColumns + var tabIdxToRetIdx []int if rowsNeeded { - // Insert always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs := makeColDescList(table, returnCols) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the delete runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + returnColumns = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Update the tabIdxToReturnColDescs for the mutation. Insert always + // returns non-mutation columns in the same order they are defined in + // the table. + tabIdxToRetIdx = mutationRowIdxToReturnIdx(tabDesc.Columns, returnColDescs) } // Regular path for INSERT. 
ins := insertNodePool.Get().(*insertNode) *ins = insertNode{ source: input.(planNode), - columns: returnCols, + columns: returnColumns, run: insertRun{ ti: tableInserter{ri: ri}, checkHelper: checkHelper, @@ -1207,7 +1248,8 @@ func (ef *execFactory) ConstructInsert( Cols: tabDesc.Columns, Mapping: ri.InsertColIDtoRowIndex, }, - insertCols: ri.InsertCols, + insertCols: ri.InsertCols, + tabIdxToRetIdx: tabIdxToRetIdx, }, } @@ -1228,10 +1270,11 @@ func (ef *execFactory) ConstructUpdate( table cat.Table, fetchCols exec.ColumnOrdinalSet, updateCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { // Derive table and column descriptors. + rowsNeeded := !returnCols.Empty() tabDesc := table.(*optTable).desc fetchColDescs := makeColDescList(table, fetchCols) @@ -1285,11 +1328,31 @@ func (ef *execFactory) ConstructUpdate( // Determine the relational type of the generated update node. // If rows are not needed, no columns are returned. - var returnCols sqlbase.ResultColumns + var returnColumns sqlbase.ResultColumns + var rowIdxToRetIdx []int if rowsNeeded { - // Update always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs := makeColDescList(table, returnCols) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the update runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + returnColumns = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Update the rowIdxToReturnColDescs for the mutation. Update returns + // the of non-mutation columns specified, in the same order they are + // defined in the table. + // + // The Updater derives/stores the fetch columns of the mutation and + // since the return columns are always a subset of the fetch columns, + // we can use use the fetch columns generate the mapping for the + // returning rows. + rowIdxToRetIdx = mutationRowIdxToReturnIdx(ru.FetchCols, returnColDescs) } // updateColsIdx inverts the mapping of UpdateCols to FetchCols. See @@ -1303,7 +1366,7 @@ func (ef *execFactory) ConstructUpdate( upd := updateNodePool.Get().(*updateNode) *upd = updateNode{ source: input.(planNode), - columns: returnCols, + columns: returnColumns, run: updateRun{ tu: tableUpdater{ru: ru}, checkHelper: checkHelper, @@ -1313,9 +1376,10 @@ func (ef *execFactory) ConstructUpdate( Cols: ru.FetchCols, Mapping: ru.FetchColIDtoRowIndex, }, - sourceSlots: sourceSlots, - updateValues: make(tree.Datums, len(ru.UpdateCols)), - updateColsIdx: updateColsIdx, + sourceSlots: sourceSlots, + updateValues: make(tree.Datums, len(ru.UpdateCols)), + updateColsIdx: updateColsIdx, + rowIdxToRetIdx: rowIdxToRetIdx, }, } @@ -1355,10 +1419,11 @@ func (ef *execFactory) ConstructUpsert( insertCols exec.ColumnOrdinalSet, fetchCols exec.ColumnOrdinalSet, updateCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { // Derive table and column descriptors. 
+ rowsNeeded := !returnCols.Empty() tabDesc := table.(*optTable).desc insertColDescs := makeColDescList(table, insertCols) fetchColDescs := makeColDescList(table, fetchCols) @@ -1405,11 +1470,27 @@ func (ef *execFactory) ConstructUpsert( // Determine the relational type of the generated upsert node. // If rows are not needed, no columns are returned. - var returnCols sqlbase.ResultColumns + var returnColumns sqlbase.ResultColumns + var returnColDescs []sqlbase.ColumnDescriptor + var tabIdxToRetIdx []int if rowsNeeded { - // Upsert always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs = makeColDescList(table, returnCols) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the delete runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + returnColumns = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Update the tabIdxToReturnColDescs for the mutation. Upsert returns + // non-mutation columns specified, in the same order they are defined + // in the table. + tabIdxToRetIdx = mutationRowIdxToReturnIdx(tabDesc.Columns, returnColDescs) } // updateColsIdx inverts the mapping of UpdateCols to FetchCols. See @@ -1424,7 +1505,7 @@ func (ef *execFactory) ConstructUpsert( ups := upsertNodePool.Get().(*upsertNode) *ups = upsertNode{ source: input.(planNode), - columns: returnCols, + columns: returnColumns, run: upsertRun{ checkHelper: checkHelper, insertCols: ri.InsertCols, @@ -1438,11 +1519,13 @@ func (ef *execFactory) ConstructUpsert( alloc: &ef.planner.alloc, collectRows: rowsNeeded, }, - canaryOrdinal: int(canaryCol), - fkTables: fkTables, - fetchCols: fetchColDescs, - updateCols: updateColDescs, - ru: ru, + canaryOrdinal: int(canaryCol), + fkTables: fkTables, + fetchCols: fetchColDescs, + updateCols: updateColDescs, + returnCols: returnColDescs, + ru: ru, + tabIdxToRetIdx: tabIdxToRetIdx, }, }, } @@ -1459,10 +1542,61 @@ func (ef *execFactory) ConstructUpsert( return &rowCountNode{source: ups}, nil } +// colsRequiredForDelete returns all the columns required to perform a delete +// of a row on the table. This will include the returnColDescs columns that +// are referenced in the RETURNING clause of the delete mutation. This +// is different from the fetch columns of the delete mutation as it doesn't +// include columns that are not part of index keys or the RETURNING columns. +func colsRequiredForDelete( + table cat.Table, tableColDescs, returnColDescs []sqlbase.ColumnDescriptor, +) []sqlbase.ColumnDescriptor { + // Find all the columns that are part of the rows returned by the delete. + deleteDescs := make([]sqlbase.ColumnDescriptor, 0, len(tableColDescs)) + colMap := make(map[sqlbase.ColumnID]struct{}, len(tableColDescs)) + for i := 0; i < table.IndexCount(); i++ { + index := table.Index(i) + for j := 0; j < index.KeyColumnCount(); j++ { + col := *index.Column(j).Column.(*sqlbase.ColumnDescriptor) + if _, ok := colMap[col.ID]; ok { + continue + } + + deleteDescs = append(deleteDescs, col) + colMap[col.ID] = struct{}{} + } + } + + // Add columns specified in the RETURNING clasue. 
+ for _, col := range returnColDescs { + if _, ok := colMap[col.ID]; ok { + continue + } + + deleteDescs = append(deleteDescs, col) + colMap[col.ID] = struct{}{} + } + + // The order of the columns processed by the delete must be in the order they + // are present in the table. + tabDescs := make([]sqlbase.ColumnDescriptor, 0, len(deleteDescs)) + for i := 0; i < len(tableColDescs); i++ { + col := tableColDescs[i] + if _, ok := colMap[col.ID]; ok { + tabDescs = append(tabDescs, col) + } + } + + return tabDescs +} + func (ef *execFactory) ConstructDelete( - input exec.Node, table cat.Table, fetchCols exec.ColumnOrdinalSet, rowsNeeded bool, + input exec.Node, + table cat.Table, + fetchCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, ) (exec.Node, error) { // Derive table and column descriptors. + rowsNeeded := !returnCols.Empty() tabDesc := table.(*optTable).desc fetchColDescs := makeColDescList(table, fetchCols) @@ -1501,21 +1635,43 @@ func (ef *execFactory) ConstructDelete( // Determine the relational type of the generated delete node. // If rows are not needed, no columns are returned. - var returnCols sqlbase.ResultColumns + var returnColumns sqlbase.ResultColumns + var rowIdxToRetIdx []int if rowsNeeded { - // Delete always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs := makeColDescList(table, returnCols) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the delete runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + // Update the tabIdxToReturnColDescs for the mutation. Delete returns + // the non-mutation columns specified, in the same order they are defined + // in the table. + returnColumns = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Find all the columns that the Deleter returns. The returning columns + // of the mutation are a subset of this column set and will use this + // for the return mapping. + requiredDeleteColumns := colsRequiredForDelete(table, tabDesc.Columns, returnColDescs) + + // Update the rowIdxToReturnIdx for the mutation. + rowIdxToRetIdx = mutationRowIdxToReturnIdx(requiredDeleteColumns, returnColDescs) } // Now make a delete node. We use a pool. del := deleteNodePool.Get().(*deleteNode) *del = deleteNode{ source: input.(planNode), - columns: returnCols, + columns: returnColumns, run: deleteRun{ - td: tableDeleter{rd: rd, alloc: &ef.planner.alloc}, - rowsNeeded: rowsNeeded, + td: tableDeleter{rd: rd, alloc: &ef.planner.alloc}, + rowsNeeded: rowsNeeded, + rowIdxToRetIdx: rowIdxToRetIdx, }, } diff --git a/pkg/sql/rowcontainer/datum_row_container.go b/pkg/sql/rowcontainer/datum_row_container.go index 6460e58975c9..9944c36f42da 100644 --- a/pkg/sql/rowcontainer/datum_row_container.go +++ b/pkg/sql/rowcontainer/datum_row_container.go @@ -244,6 +244,11 @@ func (c *RowContainer) Len() int { return c.numRows } +// NumCols reports the number of columns for each row in the container. +func (c *RowContainer) NumCols() int { + return c.numCols +} + // At accesses a row at a specific index. func (c *RowContainer) At(i int) tree.Datums { // This is a hot-path: do not add additional checks here. 
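For reference, the contract of mutationRowIdxToReturnIdx above is easiest to see with a concrete example. The following is a minimal standalone sketch, not part of the diff: it mirrors the mapping logic with plain int column IDs in place of sqlbase.ColumnDescriptor and row.ColIDtoRowIndexFromCols, purely for illustration.

package main

import "fmt"

// rowIdxToReturnIdx mirrors mutationRowIdxToReturnIdx: for every column the
// mutation processes (origIDs, in table order), it records the position that
// column occupies in the RETURNING result row, or -1 if it is not returned.
func rowIdxToReturnIdx(origIDs, returnIDs []int) []int {
	idxByID := make(map[int]int, len(origIDs))
	for i, id := range origIDs {
		idxByID[id] = i
	}
	mapping := make([]int, len(origIDs))
	for i := range mapping {
		mapping[i] = -1
	}
	for retIdx, id := range returnIDs {
		if rowIdx, ok := idxByID[id]; ok {
			mapping[rowIdx] = retIdx
		}
	}
	return mapping
}

func main() {
	// Table columns a=1, b=2, c=3, d=4; RETURNING d, b.
	fmt.Println(rowIdxToReturnIdx([]int{1, 2, 3, 4}, []int{4, 2}))
	// Output: [-1 1 -1 0]
}
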
diff --git a/pkg/sql/tablewriter_upsert_opt.go b/pkg/sql/tablewriter_upsert_opt.go index 64121acbfd3c..2cca87ce6c69 100644 --- a/pkg/sql/tablewriter_upsert_opt.go +++ b/pkg/sql/tablewriter_upsert_opt.go @@ -15,9 +15,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/internal/client" "github.com/cockroachdb/cockroach/pkg/sql/row" + "github.com/cockroachdb/cockroach/pkg/sql/rowcontainer" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" - "github.com/cockroachdb/errors" ) // optTableUpserter implements the upsert operation when it is planned by the @@ -53,6 +53,9 @@ type optTableUpserter struct { // updateCols indicate which columns need an update during a conflict. updateCols []sqlbase.ColumnDescriptor + // returnCols indicate which columns need to be returned by the Upsert. + returnCols []sqlbase.ColumnDescriptor + // canaryOrdinal is the ordinal position of the column within the input row // that is used to decide whether to execute an insert or update operation. // If the canary column is null, then an insert will be performed; otherwise, @@ -67,6 +70,14 @@ type optTableUpserter struct { // ru is used when updating rows. ru row.Updater + + // tabIdxToRetIdx is the mapping from the columns in the table to the + // columns in the resultRowBuffer. A value of -1 is used to indicate + // that the table column at that index is not part of the resultRowBuffer + // of the mutation. Otherwise, the value an the i-th index refers to the + // index of the resultRowBuffer where the i-th column of the table is + // to be returned. + tabIdxToRetIdx []int } // init is part of the tableWriter interface. @@ -77,7 +88,12 @@ func (tu *optTableUpserter) init(txn *client.Txn, evalCtx *tree.EvalContext) err } if tu.collectRows { - tu.resultRow = make(tree.Datums, len(tu.colIDToReturnIndex)) + tu.resultRow = make(tree.Datums, len(tu.returnCols)) + tu.rowsUpserted = rowcontainer.NewRowContainer( + evalCtx.Mon.MakeBoundAccount(), + sqlbase.ColTypeInfoFromColDescs(tu.returnCols), + tu.insertRows.Len(), + ) } tu.ru, err = row.MakeUpdater( @@ -161,12 +177,27 @@ func (tu *optTableUpserter) insertNonConflictingRow( // Reshape the row if needed. if tu.insertReorderingRequired { - resultRow := tu.makeResultFromRow(insertRow, tu.ri.InsertColIDtoRowIndex) - _, err := tu.rowsUpserted.AddRow(ctx, resultRow) + tableRow := tu.makeResultFromRow(insertRow, tu.ri.InsertColIDtoRowIndex) + + // TODO(ridwanmsharif): Why didn't they update the value of tu.resultRow + // before? Is it safe to be doing it now? + // Map the upserted columns into the result row before adding it. + for tabIdx := range tableRow { + if retIdx := tu.tabIdxToRetIdx[tabIdx]; retIdx >= 0 { + tu.resultRow[retIdx] = tableRow[tabIdx] + } + } + _, err := tu.rowsUpserted.AddRow(ctx, tu.resultRow) return err } - _, err := tu.rowsUpserted.AddRow(ctx, insertRow) + // Map the upserted columns into the result row before adding it. + for tabIdx := range insertRow { + if retIdx := tu.tabIdxToRetIdx[tabIdx]; retIdx >= 0 { + tu.resultRow[retIdx] = insertRow[tabIdx] + } + } + _, err := tu.rowsUpserted.AddRow(ctx, tu.resultRow) return err } @@ -208,22 +239,30 @@ func (tu *optTableUpserter) updateConflictingRow( return nil } - // We now need a row that has the shape of the result row. + // We now need a row that has the shape of the result row with + // the appropriate return columns. Make sure all the fetch columns + // are present. 
+ tableRow := tu.makeResultFromRow(fetchRow, tu.ru.FetchColIDtoRowIndex) + + // Make sure all the updated columns are present. for colID, returnIndex := range tu.colIDToReturnIndex { // If an update value for a given column exists, use that; else use the - // existing value of that column. + // existing value of that column if it has been fetched. rowIndex, ok := tu.ru.UpdateColIDtoRowIndex[colID] if ok { - tu.resultRow[returnIndex] = updateValues[rowIndex] - } else { - rowIndex, ok = tu.ru.FetchColIDtoRowIndex[colID] - if !ok { - return errors.AssertionFailedf("no existing value is available for column") - } - tu.resultRow[returnIndex] = fetchRow[rowIndex] + tableRow[returnIndex] = updateValues[rowIndex] + } + } + + // Map the upserted columns into the result row before adding it. + for tabIdx := range tableRow { + if retIdx := tu.tabIdxToRetIdx[tabIdx]; retIdx >= 0 { + tu.resultRow[retIdx] = tableRow[tabIdx] } } + // The resulting row may have nil values for columns that aren't + // being upserted, updated or fetched. _, err = tu.rowsUpserted.AddRow(ctx, tu.resultRow) return err } diff --git a/pkg/sql/update.go b/pkg/sql/update.go index 778476c9c0f8..3807353ee830 100644 --- a/pkg/sql/update.go +++ b/pkg/sql/update.go @@ -382,6 +382,12 @@ func (p *planner) Update( updateColsIdx[id] = i } + // Since all columns are being returned, use the 1:1 mapping. + rowIdxToRetIdx := make([]int, len(desc.Columns)) + for i := range rowIdxToRetIdx { + rowIdxToRetIdx[i] = i + } + un := updateNodePool.Get().(*updateNode) *un = updateNode{ source: rows, @@ -397,9 +403,10 @@ func (p *planner) Update( Cols: desc.Columns, Mapping: ru.FetchColIDtoRowIndex, }, - sourceSlots: sourceSlots, - updateValues: make(tree.Datums, len(ru.UpdateCols)), - updateColsIdx: updateColsIdx, + sourceSlots: sourceSlots, + updateValues: make(tree.Datums, len(ru.UpdateCols)), + updateColsIdx: updateColsIdx, + rowIdxToRetIdx: rowIdxToRetIdx, }, } @@ -480,6 +487,13 @@ type updateRun struct { // This provides the inverse mapping of sourceSlots. // updateColsIdx map[sqlbase.ColumnID]int + + // rowIdxToRetIdx is the mapping from the columns in ru.FetchCols to the + // columns in the resultRowBuffer. A value of -1 is used to indicate + // that the column at that index is not part of the resultRowBuffer + // of the mutation. Otherwise, the value an the i-th index refers to the + // index of the resultRowBuffer where the i-th column is to be returned. + rowIdxToRetIdx []int } // maxUpdateBatchSize is the max number of entries in the KV batch for @@ -701,7 +715,14 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) // // MakeUpdater guarantees that the first columns of the new values // are those specified u.columns. - resultValues := newValues[:len(u.columns)] + resultValues := make([]tree.Datum, len(u.columns)) + for i := range u.run.rowIdxToRetIdx { + retIdx := u.run.rowIdxToRetIdx[i] + if retIdx >= 0 { + resultValues[retIdx] = newValues[i] + } + } + if _, err := u.run.rows.AddRow(params.ctx, resultValues); err != nil { return err }
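
The consuming side of that mapping is the same in deleteNode, insertNode, updateNode and optTableUpserter: each value the mutation produces, in row order, is scattered into a result buffer shaped like the RETURNING clause. The following minimal sketch is not part of the diff; it shows that step with hypothetical names and string datums, purely for illustration.

package main

import "fmt"

// scatterToResult copies the values a mutation produced (in row order) into a
// result row shaped like the RETURNING clause, using a rowIdxToRetIdx mapping
// in which -1 marks columns that are not returned.
func scatterToResult(rowVals []string, rowIdxToRetIdx []int, numRetCols int) []string {
	result := make([]string, numRetCols)
	for i, retIdx := range rowIdxToRetIdx {
		if retIdx >= 0 {
			result[retIdx] = rowVals[i]
		}
	}
	return result
}

func main() {
	// Row values for columns a, b, c, d; only d and b are returned, in that order.
	vals := []string{"a0", "b0", "c0", "d0"}
	fmt.Println(scatterToResult(vals, []int{-1, 1, -1, 0}, 2))
	// Output: [d0 b0]
}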