Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -470,4 +470,13 @@ public static String uuid(byte[] b) {
public static boolean valueEquals(Object object1, Object object2) {
return (object1 != null && object2 != null) && object1.equals(object2);
}

public static Object coalesce(Object... objects) {
for (Object item : objects) {
if (item != null) {
return item;
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.type.SqlTypeName;
import org.codehaus.commons.compiler.CompileException;
import org.codehaus.commons.compiler.Location;
Expand Down Expand Up @@ -82,17 +83,51 @@ public static ExpressionEvaluator compileExpression(
}

public static String translateSqlNodeToJaninoExpression(SqlNode transform) {
if (transform instanceof SqlIdentifier) {
SqlIdentifier sqlIdentifier = (SqlIdentifier) transform;
return sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
} else if (transform instanceof SqlBasicCall) {
Java.Rvalue rvalue = translateJaninoAST((SqlBasicCall) transform);
Java.Rvalue rvalue = translateSqlNodeToJaninoRvalue(transform);
if (rvalue != null) {
return rvalue.toString();
}
return "";
}

private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) {
public static Java.Rvalue translateSqlNodeToJaninoRvalue(SqlNode transform) {
if (transform instanceof SqlIdentifier) {
return translateSqlIdentifier((SqlIdentifier) transform);
} else if (transform instanceof SqlBasicCall) {
return translateSqlBasicCall((SqlBasicCall) transform);
} else if (transform instanceof SqlCase) {
return translateSqlCase((SqlCase) transform);
} else if (transform instanceof SqlLiteral) {
return translateSqlSqlLiteral((SqlLiteral) transform);
}
return null;
}

private static Java.Rvalue translateSqlIdentifier(SqlIdentifier sqlIdentifier) {
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) {
return generateNoOperandTimestampFunctionOperation(columnName);
} else {
return new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName});
}
}

private static Java.Rvalue translateSqlSqlLiteral(SqlLiteral sqlLiteral) {
if (sqlLiteral.getValue() == null) {
return new Java.NullLiteral(Location.NOWHERE);
}
String value = sqlLiteral.getValue().toString();
if (sqlLiteral instanceof SqlCharStringLiteral) {
// Double quotation marks represent strings in Janino.
value = "\"" + value.substring(1, value.length() - 1) + "\"";
}
if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) {
value = "\"" + value + "\"";
}
return new Java.AmbiguousName(Location.NOWHERE, new String[] {value});
}

private static Java.Rvalue translateSqlBasicCall(SqlBasicCall sqlBasicCall) {
List<SqlNode> operandList = sqlBasicCall.getOperandList();
List<Java.Rvalue> atoms = new ArrayList<>();
for (SqlNode sqlNode : operandList) {
Expand All @@ -105,32 +140,44 @@ private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) {
return sqlBasicCallToJaninoRvalue(sqlBasicCall, atoms.toArray(new Java.Rvalue[0]));
}

private static Java.Rvalue translateSqlCase(SqlCase sqlCase) {
SqlNodeList whenOperands = sqlCase.getWhenOperands();
SqlNodeList thenOperands = sqlCase.getThenOperands();
SqlNode elseOperand = sqlCase.getElseOperand();
List<Java.Rvalue> whenAtoms = new ArrayList<>();
for (SqlNode sqlNode : whenOperands) {
translateSqlNodeToAtoms(sqlNode, whenAtoms);
}
List<Java.Rvalue> thenAtoms = new ArrayList<>();
for (SqlNode sqlNode : thenOperands) {
translateSqlNodeToAtoms(sqlNode, thenAtoms);
}
Java.Rvalue elseAtoms = translateSqlNodeToJaninoRvalue(elseOperand);
Java.Rvalue sqlCaseRvalueTemp = elseAtoms;
for (int i = whenAtoms.size() - 1; i >= 0; i--) {
sqlCaseRvalueTemp =
new Java.ConditionalExpression(
Location.NOWHERE,
whenAtoms.get(i),
thenAtoms.get(i),
sqlCaseRvalueTemp);
}
return new Java.ParenthesizedExpression(Location.NOWHERE, sqlCaseRvalueTemp);
}

private static void translateSqlNodeToAtoms(SqlNode sqlNode, List<Java.Rvalue> atoms) {
if (sqlNode instanceof SqlIdentifier) {
SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlNode;
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) {
atoms.add(generateNoOperandTimestampFunctionOperation(columnName));
} else {
atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName}));
}
atoms.add(translateSqlIdentifier((SqlIdentifier) sqlNode));
} else if (sqlNode instanceof SqlLiteral) {
SqlLiteral sqlLiteral = (SqlLiteral) sqlNode;
String value = sqlLiteral.getValue().toString();
if (sqlLiteral instanceof SqlCharStringLiteral) {
// Double quotation marks represent strings in Janino.
value = "\"" + value.substring(1, value.length() - 1) + "\"";
}
if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) {
value = "\"" + value + "\"";
}
atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {value}));
atoms.add(translateSqlSqlLiteral((SqlLiteral) sqlNode));
} else if (sqlNode instanceof SqlBasicCall) {
atoms.add(translateJaninoAST((SqlBasicCall) sqlNode));
atoms.add(translateSqlBasicCall((SqlBasicCall) sqlNode));
} else if (sqlNode instanceof SqlNodeList) {
for (SqlNode node : (SqlNodeList) sqlNode) {
translateSqlNodeToAtoms(node, atoms);
}
} else if (sqlNode instanceof SqlCase) {
atoms.add(translateSqlCase((SqlCase) sqlNode));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
Expand Down Expand Up @@ -237,10 +239,7 @@ public static String translateFilterExpressionToJaninoExpression(String filterEx
return "";
}
SqlNode where = sqlSelect.getWhere();
if (!(where instanceof SqlBasicCall)) {
throw new ParseException("Unrecognized where: " + where.toString());
}
return JaninoCompiler.translateSqlNodeToJaninoExpression((SqlBasicCall) where);
return JaninoCompiler.translateSqlNodeToJaninoExpression(where);
}

public static List<String> parseComputedColumnNames(String projection) {
Expand Down Expand Up @@ -294,11 +293,7 @@ public static List<String> parseFilterColumnNameList(String filterExpression) {
return new ArrayList<>();
}
SqlNode where = sqlSelect.getWhere();
if (!(where instanceof SqlBasicCall)) {
throw new ParseException("Unrecognized where: " + where.toString());
}
SqlBasicCall sqlBasicCall = (SqlBasicCall) where;
return parseColumnNameList(sqlBasicCall);
return parseColumnNameList(where);
}

private static List<String> parseColumnNameList(SqlNode sqlNode) {
Expand All @@ -310,6 +305,9 @@ private static List<String> parseColumnNameList(SqlNode sqlNode) {
} else if (sqlNode instanceof SqlBasicCall) {
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlCase) {
SqlCase sqlCase = (SqlCase) sqlNode;
findSqlIdentifier(sqlCase.getWhenOperands().getList(), columnNameList);
}
return columnNameList;
}
Expand All @@ -323,6 +321,10 @@ private static void findSqlIdentifier(List<SqlNode> sqlNodes, List<String> colum
} else if (sqlNode instanceof SqlBasicCall) {
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlCase) {
SqlCase sqlCase = (SqlCase) sqlNode;
SqlNodeList whenOperands = sqlCase.getWhenOperands();
findSqlIdentifier(whenOperands.getList(), columnNameList);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,23 @@ void testBuildInFunctionTransform() throws Exception {
testExpressionConditionTransform("ceil(2.4) = 3.0");
testExpressionConditionTransform("floor(2.5) = 2.0");
testExpressionConditionTransform("round(3.1415926,2) = 3.14");
testExpressionConditionTransform("IF(2>0,1,0) = 1");
testExpressionConditionTransform("COALESCE(null,1,2) = 1");
testExpressionConditionTransform("1 + 1 = 2");
testExpressionConditionTransform("1 - 1 = 0");
testExpressionConditionTransform("1 * 1 = 1");
testExpressionConditionTransform("3 % 2 = 1");
testExpressionConditionTransform("1 < 2");
testExpressionConditionTransform("1 <= 1");
testExpressionConditionTransform("1 > 0");
testExpressionConditionTransform("1 >= 1");
testExpressionConditionTransform(
"case 1 when 1 then 'a' when 2 then 'b' else 'c' end = 'a'");
testExpressionConditionTransform("case col1 when '1' then true else false end");
testExpressionConditionTransform("case when col1 = '1' then true else false end");
}

void testExpressionConditionTransform(String expression) throws Exception {
private void testExpressionConditionTransform(String expression) throws Exception {
TransformDataOperator transform =
TransformDataOperator.newBuilder()
.addTransform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.runtime.parser.metadata.TransformSchemaFactory;
import org.apache.flink.cdc.runtime.parser.metadata.TransformSqlOperatorTable;
import org.apache.flink.table.api.ApiExpression;
import org.apache.flink.table.api.Expressions;

import org.apache.calcite.config.CalciteConnectionConfigImpl;
import org.apache.calcite.jdbc.CalciteSchema;
Expand Down Expand Up @@ -260,13 +258,12 @@ public void testTranslateFilterToJaninoExpression() {
testFilterExpression("upper(lower(id))", "upper(lower(id))");
testFilterExpression(
"abs(uniq_id) > 10 and id is not null", "abs(uniq_id) > 10 && null != id");
}

@Test
public void testSqlCall() {
ApiExpression apiExpression = Expressions.concat("1", "2");
ApiExpression substring = apiExpression.substring(1);
System.out.println(substring);
testFilterExpression(
"case id when 1 then 'a' when 2 then 'b' else 'c' end",
"(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")");
testFilterExpression(
"case when id = 1 then 'a' when id = 2 then 'b' else 'c' end",
"(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")");
}

private void testFilterExpression(String expression, String expressionExpect) {
Expand Down