Skip to content

Commit

Permalink
[GIE Compiler] Support Case When Expression in Logical and Physical P…
Browse files Browse the repository at this point in the history
…lan (#2918)

<!--
Thanks for your contribution! please review
https://github.com/alibaba/GraphScope/blob/main/CONTRIBUTING.md before
opening an issue.
-->

## What do these changes do?
as titled.

<!-- Please give a short brief about these changes. -->

## Related issue number

<!-- Are there any issues opened that will be resolved by merging this
change? -->

#2686

---------

Co-authored-by: BingqingLyu <bingqing.lbq@alibaba-inc.com>
Co-authored-by: Longbin Lai <longbin.lailb@alibaba-inc.com>
  • Loading branch information
3 people authored Jul 3, 2023
1 parent b82b748 commit 7fc14b5
Show file tree
Hide file tree
Showing 15 changed files with 661 additions and 9 deletions.
23 changes: 23 additions & 0 deletions interactive_engine/compiler/src/main/antlr4/CypherGS.g4
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,34 @@ oC_Atom
| oC_FunctionInvocation
| oC_CountAny
| oC_Parameter
| oC_CaseExpression
;

oC_Parameter
: '$' ( oC_SymbolicName ) ;

oC_CaseExpression
: ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_InputExpression ( SP? oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_ElseExpression )? SP? END ;

oC_InputExpression
: oC_Expression ;

oC_ElseExpression
: oC_Expression ;

CASE : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

ELSE : ( 'E' | 'e' ) ( 'L' | 'l' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

END : ( 'E' | 'e' ) ( 'N' | 'n' ) ( 'D' | 'd' ) ;

oC_CaseAlternative
: WHEN SP? oC_Expression SP? THEN SP? oC_Expression ;

WHEN : ( 'W' | 'w' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

THEN : ( 'T' | 't' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

oC_CountAny
: ( COUNT SP? '(' SP? '*' SP? ')' )
;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright 2020 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphscope.common.ir.rex.operator;

import static org.apache.calcite.util.Static.RESOURCE;

import static java.util.Objects.requireNonNull;

import com.alibaba.graphscope.common.ir.rex.RexCallBinding;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlTypeUtil;

import java.util.ArrayList;
import java.util.List;

public class CaseOperator extends SqlOperator {

public CaseOperator(SqlOperandTypeInference operandTypeInference) {
super("CASE", SqlKind.CASE, MDX_PRECEDENCE, true, null, operandTypeInference, null);
}

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
Preconditions.checkArgument(callBinding instanceof RexCallBinding);
boolean foundNotNull = false;
int operandCount = callBinding.getOperandCount();
for (int i = 0; i < operandCount - 1; ++i) {
RelDataType type = callBinding.getOperandType(i);
if ((i & 1) == 0) { // when expression, should be boolean
if (!SqlTypeUtil.inBooleanFamily(type)) {
if (throwOnFailure) {
throw new IllegalArgumentException(
"Expected a boolean type at operand idx = " + i);
}
return false;
}
} else { // then expression
if (!callBinding.isOperandNull(i, false)) {
foundNotNull = true;
}
}
}

if (operandCount > 2 && !callBinding.isOperandNull(operandCount - 1, false)) {
foundNotNull = true;
}

if (!foundNotNull) {
// according to the sql standard we can not have all of the THEN
// statements and the ELSE returning null
if (throwOnFailure && !callBinding.isTypeCoercionEnabled()) {
throw callBinding.newValidationError(RESOURCE.mustNotNullInElse());
}
return false;
}
return true;
}

@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
return inferTypeFromOperands(opBinding);
}

private static RelDataType inferTypeFromOperands(SqlOperatorBinding opBinding) {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final List<RelDataType> argTypes = opBinding.collectOperandTypes();
assert (argTypes.size() % 2) == 1 : "odd number of arguments expected: " + argTypes.size();
assert argTypes.size() > 1
: "CASE must have more than 1 argument. Given " + argTypes.size() + ", " + argTypes;
List<RelDataType> thenTypes = new ArrayList<>();
for (int j = 1; j < (argTypes.size() - 1); j += 2) {
RelDataType argType = argTypes.get(j);
thenTypes.add(argType);
}

thenTypes.add(Iterables.getLast(argTypes));
return requireNonNull(
typeFactory.leastRestrictive(thenTypes),
() -> "Can't find leastRestrictive type for " + thenTypes);
}

@Override
public SqlOperandCountRange getOperandCountRange() {
return SqlOperandCountRanges.any();
}

@Override
public SqlSyntax getSyntax() {
return SqlSyntax.SPECIAL;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import com.alibaba.graphscope.gaia.proto.Common;
import com.alibaba.graphscope.gaia.proto.DataType;
import com.alibaba.graphscope.gaia.proto.OuterExpression;
import com.google.common.base.Preconditions;

import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;

/**
Expand All @@ -41,8 +43,38 @@ public OuterExpression.Expression visitCall(RexCall call) {
if (!this.deep) {
return null;
}
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
SqlOperator operator = call.getOperator();
if (operator.getKind() == SqlKind.CASE) {
return visitCase(call);
} else {
return visitOperator(call);
}
}

private OuterExpression.Expression visitCase(RexCall call) {
OuterExpression.Case.Builder caseBuilder = OuterExpression.Case.newBuilder();
int operandCount = call.getOperands().size();
Preconditions.checkArgument(operandCount > 2 && (operandCount & 1) == 1);
for (int i = 1; i < operandCount - 1; i += 2) {
RexNode whenNode = call.getOperands().get(i - 1);
RexNode thenNode = call.getOperands().get(i);
caseBuilder.addWhenThenExpressions(
OuterExpression.Case.WhenThen.newBuilder()
.setWhenExpression(whenNode.accept(this))
.setThenResultExpression(thenNode.accept(this)));
}
caseBuilder.setElseResultExpression(call.getOperands().get(operandCount - 1).accept(this));
return OuterExpression.Expression.newBuilder()
.addOperators(
OuterExpression.ExprOpr.newBuilder()
.setCase(caseBuilder)
.setNodeType(Utils.protoIrDataType(call.getType(), isColumnId)))
.build();
}

private OuterExpression.Expression visitOperator(RexCall call) {
SqlOperator operator = call.getOperator();
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
// left-associative
if (operator.getLeftPrec() <= operator.getRightPrec()) {
for (int i = 0; i < call.getOperands().size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
public abstract class Utils {
public static final Common.Value protoValue(RexLiteral literal) {
switch (literal.getType().getSqlTypeName()) {
case NULL:
return Common.Value.newBuilder().setNone(Common.None.newBuilder().build()).build();
case BOOLEAN:
return Common.Value.newBuilder().setBoolean((Boolean) literal.getValue()).build();
case INTEGER:
Expand Down Expand Up @@ -176,6 +178,8 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator)

public static final Common.DataType protoBasicDataType(RelDataType basicType) {
switch (basicType.getSqlTypeName()) {
case NULL:
return Common.DataType.NONE;
case BOOLEAN:
return Common.DataType.BOOLEAN;
case INTEGER:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,27 +464,29 @@ private RexNode call_(SqlOperator operator, List<RexNode> operandList) {
throw new UnsupportedOperationException(
"operator " + operator.getKind().name() + " not supported");
}
operandList = inferOperandTypes(operator, operandList);
RexCallBinding callBinding =
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
// check count of operands, if fail throw exceptions
operator.validRexOperands(callBinding.getOperandCount(), Litmus.THROW);
// check type of each operand, if fail throw exceptions
operator.checkOperandTypes(callBinding, true);
// derive type
RelDataType type = operator.inferReturnType(callBinding);
// derive return type
RelDataType returnType = operator.inferReturnType(callBinding);
// derive unknown types of operands
operandList = inferOperandTypes(operator, returnType, operandList);
final RexBuilder builder = cluster.getRexBuilder();
return builder.makeCall(type, operator, operandList);
return builder.makeCall(returnType, operator, operandList);
}

private List<RexNode> inferOperandTypes(SqlOperator operator, List<RexNode> operandList) {
private List<RexNode> inferOperandTypes(
SqlOperator operator, RelDataType returnType, List<RexNode> operandList) {
if (operator.getOperandTypeInference() != null
&& operandList.stream()
.anyMatch((t) -> t.getType().getSqlTypeName() == SqlTypeName.UNKNOWN)) {
RexCallBinding callBinding =
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
RelDataType[] newTypes = callBinding.collectOperandTypes().toArray(new RelDataType[0]);
operator.getOperandTypeInference().inferOperandTypes(callBinding, null, newTypes);
operator.getOperandTypeInference().inferOperandTypes(callBinding, returnType, newTypes);
List<RexNode> typeInferredOperands = new ArrayList<>(operandList.size());
GraphRexBuilder rexBuilder = (GraphRexBuilder) this.getRexBuilder();
for (int i = 0; i < operandList.size(); ++i) {
Expand All @@ -507,7 +509,8 @@ private boolean isCurrentSupported(SqlOperator operator) {
|| sqlKind == SqlKind.OR
|| sqlKind == SqlKind.DESCENDING
|| (sqlKind == SqlKind.OTHER_FUNCTION && operator.getName().equals("POWER"))
|| (sqlKind == SqlKind.MINUS_PREFIX);
|| (sqlKind == SqlKind.MINUS_PREFIX)
|| sqlKind == SqlKind.CASE;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.alibaba.graphscope.common.ir.tools;

import com.alibaba.graphscope.common.ir.rex.operator.CaseOperator;

import org.apache.calcite.sql.*;
import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
Expand Down Expand Up @@ -175,4 +177,6 @@ public class GraphStdOperatorTable extends SqlStdOperatorTable {
ReturnTypes.BOOLEAN_NULLABLE,
GraphInferTypes.FIRST_KNOWN,
OperandTypes.COMPARABLE_ORDERED_COMPARABLE_ORDERED);

public static final SqlOperator CASE = new CaseOperator(GraphInferTypes.RETURN_TYPE);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import com.alibaba.graphscope.cypher.antlr4.visitor.type.ExprVisitorResult;
import com.alibaba.graphscope.grammar.CypherGSBaseVisitor;
import com.alibaba.graphscope.grammar.CypherGSParser;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
Expand Down Expand Up @@ -255,6 +257,35 @@ public ExprVisitorResult visitOC_CountAny(CypherGSParser.OC_CountAnyContext ctx)
RexTmpVariable.of(alias, ((GraphAggCall) aggCall).getType()));
}

@Override
public ExprVisitorResult visitOC_CaseExpression(CypherGSParser.OC_CaseExpressionContext ctx) {
ExprVisitorResult inputExpr =
ctx.oC_InputExpression() == null
? null
: visitOC_InputExpression(ctx.oC_InputExpression());
List<RexNode> operands = Lists.newArrayList();
for (CypherGSParser.OC_CaseAlternativeContext whenThen : ctx.oC_CaseAlternative()) {
Preconditions.checkArgument(
whenThen.oC_Expression().size() == 2,
"whenThen expression should have 2 parts");
ExprVisitorResult whenExpr = visitOC_Expression(whenThen.oC_Expression(0));
if (inputExpr != null) {
operands.add(builder.equals(inputExpr.getExpr(), whenExpr.getExpr()));
} else {
operands.add(whenExpr.getExpr());
}
ExprVisitorResult thenExpr = visitOC_Expression(whenThen.oC_Expression(1));
operands.add(thenExpr.getExpr());
}
// if else expression is omitted, the default value is null
ExprVisitorResult elseExpr =
ctx.oC_ElseExpression() == null
? new ExprVisitorResult(builder.literal(null))
: visitOC_ElseExpression(ctx.oC_ElseExpression());
operands.add(elseExpr.getExpr());
return new ExprVisitorResult(builder.call(GraphStdOperatorTable.CASE, operands));
}

private ExprVisitorResult binaryCall(
List<SqlOperator> operators, List<ExprVisitorResult> operands) {
ObjectUtils.requireNonEmpty(operands, "operands count should not be 0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,22 @@ private GraphInferTypes() {}
}
Arrays.fill(operandTypes, knownType);
};

/**
* Operand type-inference strategy where an unknown operand type is derived
* from the call's return type. If the return type is a record, it must have
* the same number of fields as the number of operands.
*/
public static final SqlOperandTypeInference RETURN_TYPE =
(callBinding, returnType, operandTypes) -> {
RelDataType unknownType = callBinding.getTypeFactory().createUnknownType();
for (int i = 0; i < operandTypes.length; ++i) {
if (operandTypes[i].equals(unknownType)) {
operandTypes[i] =
returnType.isStruct()
? returnType.getFieldList().get(i).getType()
: returnType;
}
}
};
}
Loading

0 comments on commit 7fc14b5

Please sign in to comment.