Skip to content

Commit

Permalink
planner: support predicate pushdown for CTE (#33158)
Browse files Browse the repository at this point in the history
close #28163
  • Loading branch information
wjhuang2016 authored Mar 17, 2022
1 parent dcafe8e commit af1ea80
Show file tree
Hide file tree
Showing 9 changed files with 772 additions and 148 deletions.
54 changes: 54 additions & 0 deletions cmd/explaintest/r/agg_predicate_pushdown.result
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
drop database if exists agg_predicate_pushdown;
create database agg_predicate_pushdown;
create table t(a int, b int, c int);
desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1) and (a > 2) and 1 and (b > 2) and (avg(c) > 3);
id estRows task access object operator info
Projection 711.11 root test.t.a, test.t.b, Column#5
└─Selection 711.11 root gt(Column#5, 3)
└─HashAgg 888.89 root group by:Column#16, Column#17, Column#18, funcs:avg(Column#13)->Column#5, funcs:firstrow(Column#14)->test.t.a, funcs:firstrow(Column#15)->test.t.b
└─Projection 1111.11 root cast(test.t.c, decimal(15,4) BINARY)->Column#13, test.t.a, test.t.b, test.t.a, test.t.b, test.t.c
└─TableReader 1111.11 root data:Selection
└─Selection 1111.11 cop[tikv] gt(test.t.a, 1), gt(test.t.a, 2), gt(test.t.b, 2)
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 or b > 2) and (a > 2 or b < 1) and 1 and (b > 2) and (avg(c) > 3);
id estRows task access object operator info
Projection 657.65 root test.t.a, test.t.b, Column#5
└─Selection 657.65 root gt(Column#5, 3)
└─HashAgg 822.06 root group by:Column#16, Column#17, Column#18, funcs:avg(Column#13)->Column#5, funcs:firstrow(Column#14)->test.t.a, funcs:firstrow(Column#15)->test.t.b
└─Projection 1027.57 root cast(test.t.c, decimal(15,4) BINARY)->Column#13, test.t.a, test.t.b, test.t.a, test.t.b, test.t.c
└─TableReader 1027.57 root data:Selection
└─Selection 1027.57 cop[tikv] gt(test.t.b, 2), or(gt(test.t.a, 1), gt(test.t.b, 2)), or(gt(test.t.a, 2), lt(test.t.b, 1))
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 and b > 2) or (a > 2 and b < 1) or (b > 2 and avg(c) > 3);
id estRows task access object operator info
Projection 3027.54 root test.t.a, test.t.b, Column#5
└─Selection 3027.54 root or(and(gt(test.t.a, 1), gt(test.t.b, 2)), or(and(gt(test.t.a, 2), lt(test.t.b, 1)), and(gt(test.t.b, 2), gt(Column#5, 3))))
└─HashAgg 3784.43 root group by:Column#16, Column#17, Column#18, funcs:avg(Column#13)->Column#5, funcs:firstrow(Column#14)->test.t.a, funcs:firstrow(Column#15)->test.t.b
└─Projection 4730.53 root cast(test.t.c, decimal(15,4) BINARY)->Column#13, test.t.a, test.t.b, test.t.a, test.t.b, test.t.c
└─TableReader 4730.53 root data:Selection
└─Selection 4730.53 cop[tikv] or(and(gt(test.t.a, 1), gt(test.t.b, 2)), or(and(gt(test.t.a, 2), lt(test.t.b, 1)), gt(test.t.b, 2)))
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 or avg(c) > 1) and (a < 3);
id estRows task access object operator info
Projection 2126.93 root test.t.a, test.t.b, Column#5
└─Selection 2126.93 root or(gt(test.t.a, 1), gt(Column#5, 1))
└─HashAgg 2658.67 root group by:Column#16, Column#17, Column#18, funcs:avg(Column#13)->Column#5, funcs:firstrow(Column#14)->test.t.a, funcs:firstrow(Column#15)->test.t.b
└─Projection 3323.33 root cast(test.t.c, decimal(15,4) BINARY)->Column#13, test.t.a, test.t.b, test.t.a, test.t.b, test.t.c
└─TableReader 3323.33 root data:Selection
└─Selection 3323.33 cop[tikv] lt(test.t.a, 3)
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 and avg(c) > 1) or (a < 3);
id estRows task access object operator info
Projection 6393.60 root test.t.a, test.t.b, Column#5
└─Selection 6393.60 root or(and(gt(test.t.a, 1), gt(Column#5, 1)), lt(test.t.a, 3))
└─HashAgg 7992.00 root group by:Column#16, Column#17, Column#18, funcs:avg(Column#13)->Column#5, funcs:firstrow(Column#14)->test.t.a, funcs:firstrow(Column#15)->test.t.b
└─Projection 9990.00 root cast(test.t.c, decimal(15,4) BINARY)->Column#13, test.t.a, test.t.b, test.t.a, test.t.b, test.t.c
└─TableReader 9990.00 root data:Selection
└─Selection 9990.00 cop[tikv] or(gt(test.t.a, 1), lt(test.t.a, 3))
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
use test;
218 changes: 113 additions & 105 deletions cmd/explaintest/r/cte.result

Large diffs are not rendered by default.

278 changes: 260 additions & 18 deletions cmd/explaintest/r/explain_cte.result

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions cmd/explaintest/t/agg_predicate_pushdown.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
drop database if exists agg_predicate_pushdown;
create database agg_predicate_pushdown;

create table t(a int, b int, c int);

desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1) and (a > 2) and 1 and (b > 2) and (avg(c) > 3);

desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 or b > 2) and (a > 2 or b < 1) and 1 and (b > 2) and (avg(c) > 3);

desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 and b > 2) or (a > 2 and b < 1) or (b > 2 and avg(c) > 3);

desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 or avg(c) > 1) and (a < 3);

desc format='brief' select a, b, avg(c) from t group by a, b, c having
(a > 1 and avg(c) > 1) or (a < 3);

use test;
195 changes: 195 additions & 0 deletions cmd/explaintest/t/explain_cte.test
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,198 @@ explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 fro
explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 1) select * from cte1;
explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 100 offset 100) select * from cte1;
explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 0 offset 0) select * from cte1;

# TPC-DS Q11
CREATE TABLE `customer` (
`c_customer_sk` int(11) NOT NULL,
`c_customer_id` char(16) NOT NULL,
`c_current_cdemo_sk` int(11) DEFAULT NULL,
`c_current_hdemo_sk` int(11) DEFAULT NULL,
`c_current_addr_sk` int(11) DEFAULT NULL,
`c_first_shipto_date_sk` int(11) DEFAULT NULL,
`c_first_sales_date_sk` int(11) DEFAULT NULL,
`c_salutation` char(10) DEFAULT NULL,
`c_first_name` char(20) DEFAULT NULL,
`c_last_name` char(30) DEFAULT NULL,
`c_preferred_cust_flag` char(1) DEFAULT NULL,
`c_birth_day` int(11) DEFAULT NULL,
`c_birth_month` int(11) DEFAULT NULL,
`c_birth_year` int(11) DEFAULT NULL,
`c_birth_country` varchar(20) DEFAULT NULL,
`c_login` char(13) DEFAULT NULL,
`c_email_address` char(50) DEFAULT NULL,
`c_last_review_date_sk` int(11) DEFAULT NULL,
PRIMARY KEY (`c_customer_sk`) /*T![clustered_index] NONCLUSTERED */
);
CREATE TABLE `store_sales` (
`ss_sold_date_sk` int(11) DEFAULT NULL,
`ss_sold_time_sk` int(11) DEFAULT NULL,
`ss_item_sk` int(11) NOT NULL,
`ss_customer_sk` int(11) DEFAULT NULL,
`ss_cdemo_sk` int(11) DEFAULT NULL,
`ss_hdemo_sk` int(11) DEFAULT NULL,
`ss_addr_sk` int(11) DEFAULT NULL,
`ss_store_sk` int(11) DEFAULT NULL,
`ss_promo_sk` int(11) DEFAULT NULL,
`ss_ticket_number` int(11) NOT NULL,
`ss_quantity` int(11) DEFAULT NULL,
`ss_wholesale_cost` decimal(7,2) DEFAULT NULL,
`ss_list_price` decimal(7,2) DEFAULT NULL,
`ss_sales_price` decimal(7,2) DEFAULT NULL,
`ss_ext_discount_amt` decimal(7,2) DEFAULT NULL,
`ss_ext_sales_price` decimal(7,2) DEFAULT NULL,
`ss_ext_wholesale_cost` decimal(7,2) DEFAULT NULL,
`ss_ext_list_price` decimal(7,2) DEFAULT NULL,
`ss_ext_tax` decimal(7,2) DEFAULT NULL,
`ss_coupon_amt` decimal(7,2) DEFAULT NULL,
`ss_net_paid` decimal(7,2) DEFAULT NULL,
`ss_net_paid_inc_tax` decimal(7,2) DEFAULT NULL,
`ss_net_profit` decimal(7,2) DEFAULT NULL,
PRIMARY KEY (`ss_item_sk`,`ss_ticket_number`) /*T![clustered_index] NONCLUSTERED */
);
CREATE TABLE `date_dim` (
`d_date_sk` int(11) NOT NULL,
`d_date_id` char(16) NOT NULL,
`d_date` date DEFAULT NULL,
`d_month_seq` int(11) DEFAULT NULL,
`d_week_seq` int(11) DEFAULT NULL,
`d_quarter_seq` int(11) DEFAULT NULL,
`d_year` int(11) DEFAULT NULL,
`d_dow` int(11) DEFAULT NULL,
`d_moy` int(11) DEFAULT NULL,
`d_dom` int(11) DEFAULT NULL,
`d_qoy` int(11) DEFAULT NULL,
`d_fy_year` int(11) DEFAULT NULL,
`d_fy_quarter_seq` int(11) DEFAULT NULL,
`d_fy_week_seq` int(11) DEFAULT NULL,
`d_day_name` char(9) DEFAULT NULL,
`d_quarter_name` char(6) DEFAULT NULL,
`d_holiday` char(1) DEFAULT NULL,
`d_weekend` char(1) DEFAULT NULL,
`d_following_holiday` char(1) DEFAULT NULL,
`d_first_dom` int(11) DEFAULT NULL,
`d_last_dom` int(11) DEFAULT NULL,
`d_same_day_ly` int(11) DEFAULT NULL,
`d_same_day_lq` int(11) DEFAULT NULL,
`d_current_day` char(1) DEFAULT NULL,
`d_current_week` char(1) DEFAULT NULL,
`d_current_month` char(1) DEFAULT NULL,
`d_current_quarter` char(1) DEFAULT NULL,
`d_current_year` char(1) DEFAULT NULL,
PRIMARY KEY (`d_date_sk`) /*T![clustered_index] NONCLUSTERED */
);
CREATE TABLE `web_sales` (
`ws_sold_date_sk` int(11) DEFAULT NULL,
`ws_sold_time_sk` int(11) DEFAULT NULL,
`ws_ship_date_sk` int(11) DEFAULT NULL,
`ws_item_sk` int(11) NOT NULL,
`ws_bill_customer_sk` int(11) DEFAULT NULL,
`ws_bill_cdemo_sk` int(11) DEFAULT NULL,
`ws_bill_hdemo_sk` int(11) DEFAULT NULL,
`ws_bill_addr_sk` int(11) DEFAULT NULL,
`ws_ship_customer_sk` int(11) DEFAULT NULL,
`ws_ship_cdemo_sk` int(11) DEFAULT NULL,
`ws_ship_hdemo_sk` int(11) DEFAULT NULL,
`ws_ship_addr_sk` int(11) DEFAULT NULL,
`ws_web_page_sk` int(11) DEFAULT NULL,
`ws_web_site_sk` int(11) DEFAULT NULL,
`ws_ship_mode_sk` int(11) DEFAULT NULL,
`ws_warehouse_sk` int(11) DEFAULT NULL,
`ws_promo_sk` int(11) DEFAULT NULL,
`ws_order_number` int(11) NOT NULL,
`ws_quantity` int(11) DEFAULT NULL,
`ws_wholesale_cost` decimal(7,2) DEFAULT NULL,
`ws_list_price` decimal(7,2) DEFAULT NULL,
`ws_sales_price` decimal(7,2) DEFAULT NULL,
`ws_ext_discount_amt` decimal(7,2) DEFAULT NULL,
`ws_ext_sales_price` decimal(7,2) DEFAULT NULL,
`ws_ext_wholesale_cost` decimal(7,2) DEFAULT NULL,
`ws_ext_list_price` decimal(7,2) DEFAULT NULL,
`ws_ext_tax` decimal(7,2) DEFAULT NULL,
`ws_coupon_amt` decimal(7,2) DEFAULT NULL,
`ws_ext_ship_cost` decimal(7,2) DEFAULT NULL,
`ws_net_paid` decimal(7,2) DEFAULT NULL,
`ws_net_paid_inc_tax` decimal(7,2) DEFAULT NULL,
`ws_net_paid_inc_ship` decimal(7,2) DEFAULT NULL,
`ws_net_paid_inc_ship_tax` decimal(7,2) DEFAULT NULL,
`ws_net_profit` decimal(7,2) DEFAULT NULL,
PRIMARY KEY (`ws_item_sk`,`ws_order_number`) /*T![clustered_index] NONCLUSTERED */
);
desc format='brief' with year_total as (
select c_customer_id customer_id
,c_first_name customer_first_name
,c_last_name customer_last_name
,c_preferred_cust_flag customer_preferred_cust_flag
,c_birth_country customer_birth_country
,c_login customer_login
,c_email_address customer_email_address
,d_year dyear
,sum(ss_ext_list_price-ss_ext_discount_amt) year_total
,'s' sale_type
from customer
,store_sales
,date_dim
where c_customer_sk = ss_customer_sk
and ss_sold_date_sk = d_date_sk
group by c_customer_id
,c_first_name
,c_last_name
,c_preferred_cust_flag
,c_birth_country
,c_login
,c_email_address
,d_year
union all
select c_customer_id customer_id
,c_first_name customer_first_name
,c_last_name customer_last_name
,c_preferred_cust_flag customer_preferred_cust_flag
,c_birth_country customer_birth_country
,c_login customer_login
,c_email_address customer_email_address
,d_year dyear
,sum(ws_ext_list_price-ws_ext_discount_amt) year_total
,'w' sale_type
from customer
,web_sales
,date_dim
where c_customer_sk = ws_bill_customer_sk
and ws_sold_date_sk = d_date_sk
group by c_customer_id
,c_first_name
,c_last_name
,c_preferred_cust_flag
,c_birth_country
,c_login
,c_email_address
,d_year
)
select
t_s_secyear.customer_id
,t_s_secyear.customer_first_name
,t_s_secyear.customer_last_name
,t_s_secyear.customer_email_address
from year_total t_s_firstyear
,year_total t_s_secyear
,year_total t_w_firstyear
,year_total t_w_secyear
where t_s_secyear.customer_id = t_s_firstyear.customer_id
and t_s_firstyear.customer_id = t_w_secyear.customer_id
and t_s_firstyear.customer_id = t_w_firstyear.customer_id
and t_s_firstyear.sale_type = 's'
and t_w_firstyear.sale_type = 'w'
and t_s_secyear.sale_type = 's'
and t_w_secyear.sale_type = 'w'
and t_s_firstyear.dyear = 2001
and t_s_secyear.dyear = 2001+1
and t_w_firstyear.dyear = 2001
and t_w_secyear.dyear = 2001+1
and t_s_firstyear.year_total > 0
and t_w_firstyear.year_total > 0
and case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else 0.0 end
> case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else 0.0 end
order by t_s_secyear.customer_id
,t_s_secyear.customer_first_name
,t_s_secyear.customer_last_name
,t_s_secyear.customer_email_address
limit 100;
6 changes: 5 additions & 1 deletion planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3899,11 +3899,15 @@ func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName
cte.cteClass = &CTEClass{IsDistinct: cte.isDistinct, seedPartLogicalPlan: cte.seedLP,
recursivePartLogicalPlan: cte.recurLP, IDForStorage: cte.storageID,
optFlag: cte.optFlag, HasLimit: hasLimit, LimitBeg: limitBeg,
LimitEnd: limitEnd}
LimitEnd: limitEnd, pushDownPredicates: make([]expression.Expression, 0), ColumnMap: make(map[string]*expression.Column)}
}
var p LogicalPlan
lp := LogicalCTE{cteAsName: tn.Name, cte: cte.cteClass, seedStat: cte.seedStat}.Init(b.ctx, b.getSelectOffset())
prevSchema := cte.seedLP.Schema().Clone()
lp.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars()))
for i, col := range lp.schema.Columns {
lp.cte.ColumnMap[string(col.HashCode(nil))] = prevSchema.Columns[i]
}
p = lp
p.SetOutputNames(cte.seedLP.OutputNames())
if len(asName.String()) > 0 {
Expand Down
3 changes: 3 additions & 0 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,9 @@ type CTEClass struct {
LimitBeg uint64
LimitEnd uint64
IsInApply bool
// pushDownPredicates may be push-downed by different references.
pushDownPredicates []expression.Expression
ColumnMap map[string]*expression.Column
}

// LogicalCTE is for CTE.
Expand Down
Loading

0 comments on commit af1ea80

Please sign in to comment.