Skip to content

Commit

Permalink
[fix](Nereids) should not replace slot by Alias when do NormalizeSlot (
Browse files Browse the repository at this point in the history
…apache#24928)

when we do NormalizeToSlot, we pushed complex expression and only remain
slot of it. When we do this, we collect alias and their child and
compute its child in bottom project, remain the result slot in current
node. for example

Window(max(...), c1 as a1)

after normalization, we get

Window(max(...), a1)
+-- Project(..., c1 as a1)

But, in some cases, we remove some SlotReference by mistake, for example

Window(max(...), c1, c1 as a1)

after normalization, we get

Window(max(...), a1)
+-- Project(..., c1 as a1)

we lost the SlotReference c1. This PR fix this problem. After this Pr,
we get

Window(max(...), c1, a1)
+-- Project(..., c1, c1 as a1)
  • Loading branch information
morrySnow authored and vinlee19 committed Oct 7, 2023
1 parent b3c932a commit ee61ff5
Show file tree
Hide file tree
Showing 32 changed files with 562 additions and 502 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

Expand Down Expand Up @@ -69,8 +70,14 @@ public static NormalizeToSlotContext buildContext(
if (normalizeToSlotMap.containsKey(expression)) {
continue;
}
NormalizeToSlotTriplet normalizeToSlotTriplet =
NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression));
Alias alias = null;
// consider projects: c1, c1 as a1. we should push down both of them,
// so we could not replace c1 with c1 as a1.
// use null as alias for SlotReference to avoid replace it by another alias of it.
if (!(expression instanceof SlotReference)) {
alias = existsAliasMap.get(expression);
}
NormalizeToSlotTriplet normalizeToSlotTriplet = NormalizeToSlotTriplet.toTriplet(expression, alias);
normalizeToSlotMap.put(expression, normalizeToSlotTriplet);
}
return new NormalizeToSlotContext(normalizeToSlotMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ public void testHavingGroupBySlot() {
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))
).when(FieldChecker.check("projects", ImmutableList.of(new Alias(new ExprId(3), a1, value.toSql()))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));

sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING value > 0";
Expand All @@ -113,7 +114,8 @@ public void testHavingGroupBySlot() {
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))
).when(FieldChecker.check("projects", ImmutableList.of(new Alias(new ExprId(3), a1, value.toSql()))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));

sql = "SELECT sum(a2) FROM t1 GROUP BY a1 HAVING a1 > 0";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,15 @@ public void inferPredicatesTest12() {
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("id > 1")),
logicalAggregate(
logicalProject(
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("sid > 1"))
))
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("sid > 1"))
)
)
)
)
)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.types.StringType;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Set;

public class NormalizeToSlotTest {

@Test
void testSlotReferenceWithItsAlias() {
SlotReference slotReference = new SlotReference("c1", StringType.INSTANCE);
Alias alias = new Alias(slotReference, "a1");
Set<Alias> existsAliases = ImmutableSet.of(alias);
List<Expression> sourceExpressions = ImmutableList.of(slotReference, alias);

NormalizeToSlotContext context = NormalizeToSlotContext.buildContext(existsAliases, sourceExpressions);
Assertions.assertEquals(slotReference, context.normalizeToUseSlotRef(slotReference));
Assertions.assertEquals(alias.toSlot(), context.normalizeToUseSlotRef(alias));
Assertions.assertEquals(Sets.newHashSet(sourceExpressions),
context.pushDownToNamedExpression(sourceExpressions));
}
}
21 changes: 11 additions & 10 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
-- !ds_shape_1 --
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----hashAgg[GLOBAL]
------PhysicalDistribute
--------hashAgg[LOCAL]
----------PhysicalProject
------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_returned_date_sk = date_dim.d_date_sk))otherCondition=()
--------------PhysicalProject
----------------PhysicalOlapScan[store_returns]
--------------PhysicalDistribute
----PhysicalProject
------hashAgg[GLOBAL]
--------PhysicalDistribute
----------hashAgg[LOCAL]
------------PhysicalProject
--------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_returned_date_sk = date_dim.d_date_sk))otherCondition=()
----------------PhysicalProject
------------------filter((date_dim.d_year = 2000))
--------------------PhysicalOlapScan[date_dim]
------------------PhysicalOlapScan[store_returns]
----------------PhysicalDistribute
------------------PhysicalProject
--------------------filter((date_dim.d_year = 2000))
----------------------PhysicalOlapScan[date_dim]
--PhysicalResultSink
----PhysicalTopN
------PhysicalDistribute
Expand Down
59 changes: 30 additions & 29 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query19.out
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,37 @@ PhysicalResultSink
--PhysicalTopN
----PhysicalDistribute
------PhysicalTopN
--------hashAgg[GLOBAL]
----------PhysicalDistribute
------------hashAgg[LOCAL]
--------------PhysicalProject
----------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_store_sk = store.s_store_sk))otherCondition=(( not (substring(ca_zip, 1, 5) = substring(s_zip, 1, 5))))
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN] hashCondition=((customer.c_current_addr_sk = customer_address.ca_address_sk))otherCondition=()
----------------------PhysicalProject
------------------------PhysicalOlapScan[customer_address]
----------------------PhysicalDistribute
--------PhysicalProject
----------hashAgg[GLOBAL]
------------PhysicalDistribute
--------------hashAgg[LOCAL]
----------------PhysicalProject
------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_store_sk = store.s_store_sk))otherCondition=(( not (substring(ca_zip, 1, 5) = substring(s_zip, 1, 5))))
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN] hashCondition=((customer.c_current_addr_sk = customer_address.ca_address_sk))otherCondition=()
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk))otherCondition=()
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[customer]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------hashJoin[INNER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk))otherCondition=()
----------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk))otherCondition=()
------------------------------------PhysicalProject
--------------------------------------PhysicalOlapScan[store_sales]
--------------------------PhysicalOlapScan[customer_address]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk))otherCondition=()
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[customer]
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------hashJoin[INNER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk))otherCondition=()
------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk))otherCondition=()
--------------------------------------PhysicalProject
----------------------------------------PhysicalOlapScan[store_sales]
--------------------------------------PhysicalDistribute
----------------------------------------PhysicalProject
------------------------------------------filter((item.i_manager_id = 2))
--------------------------------------------PhysicalOlapScan[item]
------------------------------------PhysicalDistribute
--------------------------------------PhysicalProject
----------------------------------------filter((item.i_manager_id = 2))
------------------------------------------PhysicalOlapScan[item]
----------------------------------PhysicalDistribute
------------------------------------PhysicalProject
--------------------------------------filter((date_dim.d_moy = 12) and (date_dim.d_year = 1999))
----------------------------------------PhysicalOlapScan[date_dim]
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------PhysicalOlapScan[store]
----------------------------------------filter((date_dim.d_moy = 12) and (date_dim.d_year = 1999))
------------------------------------------PhysicalOlapScan[date_dim]
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------PhysicalOlapScan[store]

43 changes: 21 additions & 22 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query27.out
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,29 @@ PhysicalResultSink
----------hashAgg[GLOBAL]
------------PhysicalDistribute
--------------hashAgg[LOCAL]
----------------PhysicalProject
------------------PhysicalRepeat
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_store_sk = store.s_store_sk))otherCondition=()
------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk))otherCondition=()
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk))otherCondition=()
--------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_cdemo_sk = customer_demographics.cd_demo_sk))otherCondition=()
----------------------------------PhysicalProject
------------------------------------PhysicalOlapScan[store_sales]
----------------------------------PhysicalDistribute
------------------------------------PhysicalProject
--------------------------------------filter((customer_demographics.cd_education_status = 'Secondary') and (customer_demographics.cd_gender = 'F') and (customer_demographics.cd_marital_status = 'D'))
----------------------------------------PhysicalOlapScan[customer_demographics]
----------------PhysicalRepeat
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_store_sk = store.s_store_sk))otherCondition=()
----------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk))otherCondition=()
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk))otherCondition=()
------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_cdemo_sk = customer_demographics.cd_demo_sk))otherCondition=()
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[store_sales]
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------filter((date_dim.d_year = 1999))
--------------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[item]
------------------------------------filter((customer_demographics.cd_education_status = 'Secondary') and (customer_demographics.cd_gender = 'F') and (customer_demographics.cd_marital_status = 'D'))
--------------------------------------PhysicalOlapScan[customer_demographics]
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------filter((date_dim.d_year = 1999))
------------------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------filter(s_state IN ('AL', 'LA', 'MI', 'MO', 'SC', 'TN'))
------------------------------PhysicalOlapScan[store]
----------------------------PhysicalOlapScan[item]
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------filter(s_state IN ('AL', 'LA', 'MI', 'MO', 'SC', 'TN'))
----------------------------PhysicalOlapScan[store]

33 changes: 17 additions & 16 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query3.out
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@ PhysicalResultSink
--PhysicalTopN
----PhysicalDistribute
------PhysicalTopN
--------hashAgg[GLOBAL]
----------PhysicalDistribute
------------hashAgg[LOCAL]
--------------PhysicalProject
----------------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = store_sales.ss_sold_date_sk))otherCondition=()
------------------PhysicalDistribute
--------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk))otherCondition=()
----------------------PhysicalProject
------------------------PhysicalOlapScan[store_sales]
----------------------PhysicalDistribute
--------PhysicalProject
----------hashAgg[GLOBAL]
------------PhysicalDistribute
--------------hashAgg[LOCAL]
----------------PhysicalProject
------------------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = store_sales.ss_sold_date_sk))otherCondition=()
--------------------PhysicalDistribute
----------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk))otherCondition=()
------------------------PhysicalProject
--------------------------filter((item.i_manufact_id = 816))
----------------------------PhysicalOlapScan[item]
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------filter((dt.d_moy = 11))
------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalOlapScan[store_sales]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------filter((item.i_manufact_id = 816))
------------------------------PhysicalOlapScan[item]
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------filter((dt.d_moy = 11))
--------------------------PhysicalOlapScan[date_dim]

33 changes: 17 additions & 16 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@
-- !ds_shape_30 --
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----hashAgg[GLOBAL]
------PhysicalDistribute
--------hashAgg[LOCAL]
----------PhysicalProject
------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_returning_addr_sk = customer_address.ca_address_sk))otherCondition=()
--------------PhysicalDistribute
----------------PhysicalProject
------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_returned_date_sk = date_dim.d_date_sk))otherCondition=()
--------------------PhysicalProject
----------------------PhysicalOlapScan[web_returns]
--------------------PhysicalDistribute
----PhysicalProject
------hashAgg[GLOBAL]
--------PhysicalDistribute
----------hashAgg[LOCAL]
------------PhysicalProject
--------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_returning_addr_sk = customer_address.ca_address_sk))otherCondition=()
----------------PhysicalDistribute
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_returned_date_sk = date_dim.d_date_sk))otherCondition=()
----------------------PhysicalProject
------------------------filter((date_dim.d_year = 2002))
--------------------------PhysicalOlapScan[date_dim]
--------------PhysicalDistribute
----------------PhysicalProject
------------------PhysicalOlapScan[customer_address]
------------------------PhysicalOlapScan[web_returns]
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------filter((date_dim.d_year = 2002))
----------------------------PhysicalOlapScan[date_dim]
----------------PhysicalDistribute
------------------PhysicalProject
--------------------PhysicalOlapScan[customer_address]
--PhysicalResultSink
----PhysicalTopN
------PhysicalDistribute
Expand Down
Loading

0 comments on commit ee61ff5

Please sign in to comment.