Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

import com.google.common.collect.ImmutableList;
import java.util.List;
import org.apache.calcite.rel.convert.ConverterRule;
import org.apache.calcite.plan.RelOptRule;

public class OpenSearchRules {
public static final List<ConverterRule> OPEN_SEARCH_OPT_RULES = ImmutableList.of();
private static final PPLAggregateConvertRule AGGREGATE_CONVERT_RULE =
PPLAggregateConvertRule.Config.SUM_CONVERTER.toRule();

public static final List<RelOptRule> OPEN_SEARCH_OPT_RULES =
ImmutableList.of(AGGREGATE_CONVERT_RULE);

// prevent instantiation
private OpenSearchRules() {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.sql.calcite.plan;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.commons.lang3.tuple.Pair;
import org.immutables.value.Value;

/**
* Planner rule that converts specific aggCall to a more efficient expressions, which includes:
*
* <p>- SUM(FIELD + NUMBER) -> SUM(FIELD) + NUMBER * COUNT()
*
* <p>- SUM(FIELD - NUMBER) -> SUM(FIELD) - NUMBER * COUNT()
*
* <p>- SUM(FIELD * NUMBER) -> SUM(FIELD) * NUMBER
*
* <p>- SUM(FIELD / NUMBER) -> SUM(FIELD) / NUMBER, Don't support this because of precision issue
*
* <p>TODO:
*
* <p>- AVG/MAX/MIN(FIELD [+|-|*|+|/] NUMBER) -> AVG/MAX/MIN(FIELD) [+|-|*|+|/] NUMBER
*/
@Value.Enclosing
public class PPLAggregateConvertRule extends RelRule<PPLAggregateConvertRule.Config> {

/** Creates a OpenSearchAggregateConvertRule. */
protected PPLAggregateConvertRule(Config config) {
super(config);
}

@Override
public void onMatch(RelOptRuleCall call) {
if (call.rels.length == 2) {
final LogicalAggregate aggregate = call.rel(0);
final LogicalProject project = call.rel(1);
apply(call, aggregate, project);
} else {
throw new AssertionError(
String.format(
"The length of rels should be %s but got %s",
this.operands.size(), call.rels.length));
}
}

public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) {

final RelBuilder relBuilder = call.builder();
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
relBuilder.push(project.getInput());

/*
Build new projects with fields to be used in the converted agg call.
Need to build this project in advance since building converted agg call has dependency on it.
*/
List<AggregateCall> aggCalls = aggregate.getAggCallList();
final List<RexNode> newChildProjects = new ArrayList<>(project.getProjects());
List<Integer> convertedAggCallArgs =
aggCalls.stream()
.filter(aggCall -> isConvertableAggCall(aggCall, project))
.map(
aggCall -> {
RexInputRef rexRef =
getFieldAndLiteral(project.getProjects().get(aggCall.getArgList().get(0)))
.getLeft();
// Don't remove elements in the child project since we don't know if it will be
// used by
// other aggCall, will handle unused projects later
int ref = newChildProjects.indexOf(rexRef);
if (ref == -1) {
ref = newChildProjects.size();
newChildProjects.add(rexRef);
}
return ref;
})
.collect(Collectors.toList());
relBuilder.project(newChildProjects);
RelNode newInput = relBuilder.peek();

/* Build converted agg call and its parent projects */
int convertedAggCallCnt = 0;
final int groupSetOffset = aggregate.getGroupSet().cardinality();
final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
final PairList<OperatorConstructor, String> newExprOnAggCall = PairList.of();
for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
AggregateCall aggCall = aggregate.getAggCallList().get(i);
if (isConvertableAggCall(aggCall, project)) {
// The arg ref of convertable aggCall starts at the end of the project
int argRef = convertedAggCallArgs.get(convertedAggCallCnt++);
AggregateCall sumCall =
AggregateCall.create(
aggCall.getParserPosition(),
aggCall.getAggregation(),
aggCall.isDistinct(),
aggCall.isApproximate(),
aggCall.ignoreNulls(),
aggCall.rexList,
ImmutableList.of(argRef),
aggCall.filterArg,
aggCall.distinctKeys,
aggCall.collation,
aggregate.getGroupCount(),
newInput, // Note: must be the new Project
null, // The type will be inferred.
aggCall.getName() + "_SUM");
int sumCallRef = putToDistinctAggregateCalls(distinctAggregateCalls, sumCall);

final Function<RelNode, Function<RexNode, RexNode>> literalConverterProvider;
RexCall rexCall = (RexCall) project.getProjects().get(aggCall.getArgList().get(0));
if (rexCall.getOperator().kind == SqlKind.PLUS
|| rexCall.getOperator().kind == SqlKind.MINUS) {
AggregateCall countCall =
AggregateCall.create(
aggCall.getParserPosition(),
SqlStdOperatorTable.COUNT,
aggCall.isDistinct(),
aggCall.isApproximate(),
aggCall.ignoreNulls(),
aggCall.rexList,
ImmutableList.of(argRef),
aggCall.filterArg,
aggCall.distinctKeys,
aggCall.collation,
aggregate.getGroupCount(),
newInput,
null, // The type will be inferred.
aggCall.getName() + "_COUNT");
int countCallRef = putToDistinctAggregateCalls(distinctAggregateCalls, countCall);
literalConverterProvider =
input ->
literal ->
rexBuilder.makeCall(
aggCall.getType(),
SqlStdOperatorTable.MULTIPLY,
List.of(
rexBuilder.makeInputRef(input, groupSetOffset + countCallRef),
literal));
} else {
literalConverterProvider = input -> literal -> literal;
}
newExprOnAggCall.add(
input -> {
Function<RexNode, RexNode> fieldConverter =
field -> rexBuilder.makeInputRef(input, groupSetOffset + sumCallRef);
Function<RexNode, RexNode> literalConverter = literalConverterProvider.apply(input);
List<RexNode> operands =
List.of(
convertToNewOperand(
rexCall.getOperands().get(0), fieldConverter, literalConverter),
convertToNewOperand(
rexCall.getOperands().get(1), fieldConverter, literalConverter));
return rexBuilder.makeCall(aggCall.getType(), rexCall.getOperator(), operands);
},
aggCall.getName());
} else {
int callRef = putToDistinctAggregateCalls(distinctAggregateCalls, aggCall);
newExprOnAggCall.add(
input -> rexBuilder.makeInputRef(input, groupSetOffset + callRef), aggCall.getName());
}
}

/* Eliminate unused fields in the child project */
ImmutableBitSet newGroupSet = aggregate.getGroupSet();
;
ImmutableList<ImmutableBitSet> newGroupSets = aggregate.getGroupSets();
;
final Set<Integer> fieldsUsed =
RelOptUtil.getAllFields2(aggregate.getGroupSet(), distinctAggregateCalls);
if (fieldsUsed.size() < newChildProjects.size()) {
// Some fields are computed but not used. Prune them.
final Map<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<>();
for (int source : fieldsUsed) {
sourceFieldToTargetFieldMap.put(source, sourceFieldToTargetFieldMap.size());
}
newGroupSet = aggregate.getGroupSet().permute(sourceFieldToTargetFieldMap);
newGroupSets =
ImmutableBitSet.ORDERING.immutableSortedCopy(
ImmutableBitSet.permute(aggregate.getGroupSets(), sourceFieldToTargetFieldMap));
final Mappings.TargetMapping targetMapping =
Mappings.target(sourceFieldToTargetFieldMap, newChildProjects.size(), fieldsUsed.size());
final List<AggregateCall> oldAggregateCalls = new ArrayList<>(distinctAggregateCalls);
distinctAggregateCalls.clear();
for (AggregateCall aggregateCall : oldAggregateCalls) {
distinctAggregateCalls.add(aggregateCall.transform(targetMapping));
}
// Project the used fields
relBuilder.project(relBuilder.fields(new ArrayList<>(fieldsUsed)));
}

/* Build the final project-aggregate-project after eliminating unused fields */
relBuilder.aggregate(relBuilder.groupKey(newGroupSet, newGroupSets), distinctAggregateCalls);
List<RexNode> parentProjects =
new ArrayList<>(relBuilder.fields(IntStream.range(0, groupSetOffset).boxed().collect(
Collectors.toList())));
parentProjects.addAll(
newExprOnAggCall.transform(
(constructor, name) ->
aliasMaybe(relBuilder, constructor.apply(relBuilder.peek()), name)));
relBuilder.project(parentProjects);
call.transformTo(relBuilder.build());
}

interface OperatorConstructor {
RexNode apply(RelNode input);
}

private int putToDistinctAggregateCalls(
List<AggregateCall> distinctAggregateCalls, AggregateCall aggCall) {
int i = distinctAggregateCalls.indexOf(aggCall);
if (i < 0) {
i = distinctAggregateCalls.size();
distinctAggregateCalls.add(aggCall);
}
return i;
}

private boolean isConvertableAggCall(AggregateCall aggCall, Project project) {
return aggCall.getAggregation().getKind() == SqlKind.SUM
&& Config.isCallWithLiteral(project.getProjects().get(aggCall.getArgList().get(0)));
}

private static Pair<RexInputRef, RexLiteral> getFieldAndLiteral(RexNode node) {
RexCall call = (RexCall) node;
RexNode arg1 = call.getOperands().get(0);
RexNode arg2 = call.getOperands().get(1);
return arg1.getKind() == SqlKind.INPUT_REF
? Pair.of((RexInputRef) arg1, (RexLiteral) arg2)
: Pair.of((RexInputRef) arg2, (RexLiteral) arg1);
}

private static RexNode convertToNewOperand(
RexNode operand,
Function<RexNode, RexNode> fieldConverter,
Function<RexNode, RexNode> literalConverter) {
if (operand.getKind() == SqlKind.INPUT_REF) {
return fieldConverter.apply(operand);
} else {
return literalConverter.apply(operand);
}
}

private RexNode aliasMaybe(RelBuilder builder, RexNode node, String alias) {
return alias == null ? node : builder.alias(node, alias);
}

/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config SUM_CONVERTER =
ImmutablePPLAggregateConvertRule.Config.builder()
.build()
.withOperandSupplier(
b0 ->
b0.operand(LogicalAggregate.class)
.predicate(Config::containsSumAggCall)
.oneInput(
b1 ->
b1.operand(LogicalProject.class)
.predicate(Config::containsCallWithNumber)
.anyInputs()));

static boolean containsSumAggCall(LogicalAggregate aggregate) {
return aggregate.getAggCallList().stream()
.anyMatch(aggCall -> aggCall.getAggregation().getKind() == SqlKind.SUM);
}

static boolean containsCallWithNumber(LogicalProject project) {
return project.getProjects().stream().anyMatch(Config::isCallWithLiteral);
}

private static boolean isCallWithLiteral(RexNode node) {
if (CONVERTABLE_FUNCTIONS.contains(node.getKind()) && node instanceof RexCall) {
RexCall call = (RexCall) node;
RexNode arg1 = call.getOperands().get(0);
RexNode arg2 = call.getOperands().get(1);
return (arg1.getKind() == SqlKind.INPUT_REF && arg2.getKind() == SqlKind.LITERAL)
|| (arg1.getKind() == SqlKind.LITERAL && arg2.getKind() == SqlKind.INPUT_REF);
}
return false;
}

List<SqlKind> CONVERTABLE_FUNCTIONS =
List.of(
SqlKind.PLUS, SqlKind.MINUS, SqlKind.TIMES
// Don't support division because of the issue of integer division
// e.g. (2000 / 3) * 3 = 1998 while 2000 * 3 / 3 = 2000
// SqlKind.DIVIDE
);

@Override
default PPLAggregateConvertRule toRule() {
return new PPLAggregateConvertRule(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import org.apache.calcite.util.Holder;
import org.apache.calcite.util.Util;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.plan.OpenSearchRules;
import org.opensearch.sql.calcite.plan.Scannable;
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;

Expand Down Expand Up @@ -231,10 +232,15 @@ public <R> R perform(
final RelOptPlanner planner =
createPlanner(
prepareContext, Contexts.of(prepareContext.config()), config.getCostFactory());
registerCustomizedRules(planner);
final RelOptCluster cluster = createCluster(planner, rexBuilder);
return action.apply(cluster, catalogReader, prepareContext.getRootSchema().plus(), statement);
}

private void registerCustomizedRules(RelOptPlanner planner) {
OpenSearchRules.OPEN_SEARCH_OPT_RULES.forEach(planner::addRule);
}

/**
* Customize CalcitePreparingStmt. Override {@link CalcitePrepareImpl#getPreparingStmt} and
* return {@link OpenSearchCalcitePreparingStmt}
Expand Down
Loading
Loading