Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.rules.expression.rules.AddMinMax.MinMaxValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.DiscreteValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDescVisitor;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
Expand All @@ -39,14 +45,15 @@
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import org.apache.commons.lang3.NotImplementedException;

import java.util.List;
import java.util.Map;
Expand All @@ -63,13 +70,14 @@
* a between 10 and 20 and b between 10 and 20 or a between 100 and 200 and b between 100 and 200
* => (a <= 20 and b <= 20 or a >= 100 and b >= 100) and a >= 10 and a <= 200 and b >= 10 and b <= 200
*/
public class AddMinMax implements ExpressionPatternRuleFactory {
public class AddMinMax implements ExpressionPatternRuleFactory, ValueDescVisitor<Map<Expression, MinMaxValue>, Void> {
public static final AddMinMax INSTANCE = new AddMinMax();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesTopType(CompoundPredicate.class)
.whenCtx(ctx -> PlanUtils.isConditionExpressionPlan(ctx.rewriteContext.plan.orElse(null)))
.thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext))
.toRule(ExpressionRuleType.ADD_MIN_MAX)
);
Expand All @@ -78,7 +86,7 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
/** rewrite */
public Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext context) {
ValueDesc valueDesc = (new RangeInference()).getValue(expr, context);
Map<Expression, MinMaxValue> exprMinMaxValues = getExprMinMaxValues(valueDesc);
Map<Expression, MinMaxValue> exprMinMaxValues = valueDesc.accept(this, null);
removeUnnecessaryMinMaxValues(expr, exprMinMaxValues);
if (!exprMinMaxValues.isEmpty()) {
return addExprMinMaxValues(expr, context, exprMinMaxValues);
Expand All @@ -92,7 +100,8 @@ private enum MatchMinMax {
MATCH_NONE,
}

private static class MinMaxValue {
/** record each expression's min and max value */
public static class MinMaxValue {
// min max range, if range = null means empty
Range<ComparableLiteral> range;

Expand Down Expand Up @@ -280,21 +289,8 @@ private boolean isExprNeedAddMinMax(Expression expr) {
return (expr instanceof SlotReference) && ((SlotReference) expr).getOriginalColumn().isPresent();
}

private Map<Expression, MinMaxValue> getExprMinMaxValues(ValueDesc value) {
if (value instanceof EmptyValue) {
return getExprMinMaxValues((EmptyValue) value);
} else if (value instanceof DiscreteValue) {
return getExprMinMaxValues((DiscreteValue) value);
} else if (value instanceof RangeValue) {
return getExprMinMaxValues((RangeValue) value);
} else if (value instanceof UnknownValue) {
return getExprMinMaxValues((UnknownValue) value);
} else {
throw new NotImplementedException("not implements");
}
}

private Map<Expression, MinMaxValue> getExprMinMaxValues(EmptyValue value) {
@Override
public Map<Expression, MinMaxValue> visitEmptyValue(EmptyValue value, Void context) {
Expression reference = value.getReference();
Map<Expression, MinMaxValue> exprMinMaxValues = Maps.newHashMap();
if (isExprNeedAddMinMax(reference)) {
Expand All @@ -303,7 +299,8 @@ private Map<Expression, MinMaxValue> getExprMinMaxValues(EmptyValue value) {
return exprMinMaxValues;
}

private Map<Expression, MinMaxValue> getExprMinMaxValues(DiscreteValue value) {
@Override
public Map<Expression, MinMaxValue> visitDiscreteValue(DiscreteValue value, Void context) {
Expression reference = value.getReference();
Map<Expression, MinMaxValue> exprMinMaxValues = Maps.newHashMap();
if (isExprNeedAddMinMax(reference)) {
Expand All @@ -312,7 +309,23 @@ private Map<Expression, MinMaxValue> getExprMinMaxValues(DiscreteValue value) {
return exprMinMaxValues;
}

private Map<Expression, MinMaxValue> getExprMinMaxValues(RangeValue value) {
@Override
public Map<Expression, MinMaxValue> visitNotDiscreteValue(NotDiscreteValue value, Void context) {
return ImmutableMap.of();
}

@Override
public Map<Expression, MinMaxValue> visitIsNullValue(IsNullValue value, Void context) {
return ImmutableMap.of();
}

@Override
public Map<Expression, MinMaxValue> visitIsNotNullValue(IsNotNullValue value, Void context) {
return ImmutableMap.of();
}

@Override
public Map<Expression, MinMaxValue> visitRangeValue(RangeValue value, Void context) {
Expression reference = value.getReference();
Map<Expression, MinMaxValue> exprMinMaxValues = Maps.newHashMap();
if (isExprNeedAddMinMax(reference)) {
Expand All @@ -321,16 +334,14 @@ private Map<Expression, MinMaxValue> getExprMinMaxValues(RangeValue value) {
return exprMinMaxValues;
}

private Map<Expression, MinMaxValue> getExprMinMaxValues(UnknownValue valueDesc) {
@Override
public Map<Expression, MinMaxValue> visitCompoundValue(CompoundValue valueDesc, Void context) {
List<ValueDesc> sourceValues = valueDesc.getSourceValues();
if (sourceValues.isEmpty()) {
return Maps.newHashMap();
}
Map<Expression, MinMaxValue> result = Maps.newHashMap(getExprMinMaxValues(sourceValues.get(0)));
Map<Expression, MinMaxValue> result = Maps.newHashMap(sourceValues.get(0).accept(this, context));
int nextExprOrderIndex = result.values().stream().mapToInt(k -> k.exprOrderIndex).max().orElse(0);
for (int i = 1; i < sourceValues.size(); i++) {
// process in sourceValues[i]
Map<Expression, MinMaxValue> minMaxValues = getExprMinMaxValues(sourceValues.get(i));
Map<Expression, MinMaxValue> minMaxValues = sourceValues.get(i).accept(this, context);
// merge values of sourceValues[i] into result.
// also keep the value's relative order in sourceValues[i].
// for example, if a and b in sourceValues[i], but not in result, then during merging,
Expand Down Expand Up @@ -398,4 +409,9 @@ private Map<Expression, MinMaxValue> getExprMinMaxValues(UnknownValue valueDesc)
}
return result;
}

@Override
public Map<Expression, MinMaxValue> visitUnknownValue(UnknownValue valueDesc, Void context) {
return ImmutableMap.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -67,18 +66,17 @@ protected Expression rewrite(Expression expression, ExpressionRewriteContext con
return expression.accept(this, rootIsCondition(context));
}

// for the expression root, only filter and join expression can treat as condition
protected boolean rootIsCondition(ExpressionRewriteContext context) {
Plan plan = context.plan.orElse(null);
if (plan instanceof LogicalFilter || plan instanceof LogicalHaving) {
return true;
} else if (plan instanceof LogicalJoin) {
if (plan instanceof LogicalJoin) {
// null aware join can not treat null as false
ExpressionSource source = context.source.orElse(null);
return ((LogicalJoin<?, ?>) plan).getJoinType() != JoinType.NULL_AWARE_LEFT_ANTI_JOIN
&& (source == ExpressionSource.JOIN_HASH_CONDITION
|| source == ExpressionSource.JOIN_OTHER_CONDITION);
} else {
return false;
return PlanUtils.isConditionExpressionPlan(plan);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,42 +139,30 @@ private Optional<Expression> simplifySafeEqualTrue(Expression expression) {
List<Expression> oldChildren = expression.children();
ImmutableList.Builder<Expression> newChildrenBuilder
= ImmutableList.builderWithExpectedSize(oldChildren.size());
boolean changed = false;
for (Expression child : expression.children()) {
// rewrite child to child <=> TRUE
Expression newChild = simplifySafeEqualTrue(child)
.orElse(new NullSafeEqual(child, BooleanLiteral.TRUE));
if (newChild.getClass() == expression.getClass()) {
// flatten
newChildrenBuilder.addAll(newChild.children());
changed = true;
} else {
changed = changed || child != newChild;
newChildrenBuilder.add(newChild);
}
}
List<Expression> newChildren = newChildrenBuilder.build();
boolean changed = newChildren.size() != oldChildren.size();
if (newChildren.size() == oldChildren.size()) {
for (int i = 0; i < newChildren.size(); i++) {
if (newChildren.get(i) != oldChildren.get(i)) {
changed = true;
break;
}
}
}
return Optional.of(changed ? expression.withChildren(newChildren) : expression);
return Optional.of(changed ? expression.withChildren(newChildrenBuilder.build()) : expression);
}
return Optional.empty();
}

private boolean tryProcessPropagateNullable(Expression expression, Set<Expression> conjuncts) {
if (expression.isLiteral()) {
// for propagate nullable function, if any of its child is null literal,
// the fold rule will simplify it to null literal.
// so here no need to handle with the null literal case.
return !expression.isNullLiteral();
if (!expression.nullable()) {
return true;
} else if (expression instanceof SlotReference) {
if (expression.nullable()) {
conjuncts.add(ExpressionUtils.notIsNull(expression));
}
conjuncts.add(ExpressionUtils.notIsNull(expression));
return true;
} else if (expression instanceof PropagateNullable) {
for (Expression child : expression.children()) {
Expand Down
Loading
Loading