diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/SystemFunctionUtils.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/SystemFunctionUtils.java index 03128e3993c..724bc3e14df 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/SystemFunctionUtils.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/functions/SystemFunctionUtils.java @@ -470,4 +470,13 @@ public static String uuid(byte[] b) { public static boolean valueEquals(Object object1, Object object2) { return (object1 != null && object2 != null) && object1.equals(object2); } + + public static Object coalesce(Object... objects) { + for (Object item : objects) { + if (item != null) { + return item; + } + } + return null; + } } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java index 373c097bb3f..5af9755edf3 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/JaninoCompiler.java @@ -27,6 +27,7 @@ import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.fun.SqlCase; import org.apache.calcite.sql.type.SqlTypeName; import org.codehaus.commons.compiler.CompileException; import org.codehaus.commons.compiler.Location; @@ -82,17 +83,51 @@ public static ExpressionEvaluator compileExpression( } public static String translateSqlNodeToJaninoExpression(SqlNode transform) { - if (transform instanceof SqlIdentifier) { - SqlIdentifier sqlIdentifier = (SqlIdentifier) transform; - return sqlIdentifier.names.get(sqlIdentifier.names.size() - 1); - } else if (transform instanceof SqlBasicCall) { - Java.Rvalue rvalue = translateJaninoAST((SqlBasicCall) transform); + Java.Rvalue rvalue = translateSqlNodeToJaninoRvalue(transform); + if (rvalue != null) { return rvalue.toString(); } return ""; } - private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) { + public static Java.Rvalue translateSqlNodeToJaninoRvalue(SqlNode transform) { + if (transform instanceof SqlIdentifier) { + return translateSqlIdentifier((SqlIdentifier) transform); + } else if (transform instanceof SqlBasicCall) { + return translateSqlBasicCall((SqlBasicCall) transform); + } else if (transform instanceof SqlCase) { + return translateSqlCase((SqlCase) transform); + } else if (transform instanceof SqlLiteral) { + return translateSqlSqlLiteral((SqlLiteral) transform); + } + return null; + } + + private static Java.Rvalue translateSqlIdentifier(SqlIdentifier sqlIdentifier) { + String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1); + if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) { + return generateNoOperandTimestampFunctionOperation(columnName); + } else { + return new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName}); + } + } + + private static Java.Rvalue translateSqlSqlLiteral(SqlLiteral sqlLiteral) { + if (sqlLiteral.getValue() == null) { + return new Java.NullLiteral(Location.NOWHERE); + } + String value = sqlLiteral.getValue().toString(); + if (sqlLiteral instanceof SqlCharStringLiteral) { + // Double quotation marks represent strings in Janino. + value = "\"" + value.substring(1, value.length() - 1) + "\""; + } + if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) { + value = "\"" + value + "\""; + } + return new Java.AmbiguousName(Location.NOWHERE, new String[] {value}); + } + + private static Java.Rvalue translateSqlBasicCall(SqlBasicCall sqlBasicCall) { List operandList = sqlBasicCall.getOperandList(); List atoms = new ArrayList<>(); for (SqlNode sqlNode : operandList) { @@ -105,32 +140,44 @@ private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) { return sqlBasicCallToJaninoRvalue(sqlBasicCall, atoms.toArray(new Java.Rvalue[0])); } + private static Java.Rvalue translateSqlCase(SqlCase sqlCase) { + SqlNodeList whenOperands = sqlCase.getWhenOperands(); + SqlNodeList thenOperands = sqlCase.getThenOperands(); + SqlNode elseOperand = sqlCase.getElseOperand(); + List whenAtoms = new ArrayList<>(); + for (SqlNode sqlNode : whenOperands) { + translateSqlNodeToAtoms(sqlNode, whenAtoms); + } + List thenAtoms = new ArrayList<>(); + for (SqlNode sqlNode : thenOperands) { + translateSqlNodeToAtoms(sqlNode, thenAtoms); + } + Java.Rvalue elseAtoms = translateSqlNodeToJaninoRvalue(elseOperand); + Java.Rvalue sqlCaseRvalueTemp = elseAtoms; + for (int i = whenAtoms.size() - 1; i >= 0; i--) { + sqlCaseRvalueTemp = + new Java.ConditionalExpression( + Location.NOWHERE, + whenAtoms.get(i), + thenAtoms.get(i), + sqlCaseRvalueTemp); + } + return new Java.ParenthesizedExpression(Location.NOWHERE, sqlCaseRvalueTemp); + } + private static void translateSqlNodeToAtoms(SqlNode sqlNode, List atoms) { if (sqlNode instanceof SqlIdentifier) { - SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlNode; - String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1); - if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) { - atoms.add(generateNoOperandTimestampFunctionOperation(columnName)); - } else { - atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName})); - } + atoms.add(translateSqlIdentifier((SqlIdentifier) sqlNode)); } else if (sqlNode instanceof SqlLiteral) { - SqlLiteral sqlLiteral = (SqlLiteral) sqlNode; - String value = sqlLiteral.getValue().toString(); - if (sqlLiteral instanceof SqlCharStringLiteral) { - // Double quotation marks represent strings in Janino. - value = "\"" + value.substring(1, value.length() - 1) + "\""; - } - if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) { - value = "\"" + value + "\""; - } - atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {value})); + atoms.add(translateSqlSqlLiteral((SqlLiteral) sqlNode)); } else if (sqlNode instanceof SqlBasicCall) { - atoms.add(translateJaninoAST((SqlBasicCall) sqlNode)); + atoms.add(translateSqlBasicCall((SqlBasicCall) sqlNode)); } else if (sqlNode instanceof SqlNodeList) { for (SqlNode node : (SqlNodeList) sqlNode) { translateSqlNodeToAtoms(node, atoms); } + } else if (sqlNode instanceof SqlCase) { + atoms.add(translateSqlCase((SqlCase) sqlNode)); } } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java index d89e0e359c6..feb791157a6 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java @@ -43,7 +43,9 @@ import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.fun.SqlCase; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; @@ -237,10 +239,7 @@ public static String translateFilterExpressionToJaninoExpression(String filterEx return ""; } SqlNode where = sqlSelect.getWhere(); - if (!(where instanceof SqlBasicCall)) { - throw new ParseException("Unrecognized where: " + where.toString()); - } - return JaninoCompiler.translateSqlNodeToJaninoExpression((SqlBasicCall) where); + return JaninoCompiler.translateSqlNodeToJaninoExpression(where); } public static List parseComputedColumnNames(String projection) { @@ -294,11 +293,7 @@ public static List parseFilterColumnNameList(String filterExpression) { return new ArrayList<>(); } SqlNode where = sqlSelect.getWhere(); - if (!(where instanceof SqlBasicCall)) { - throw new ParseException("Unrecognized where: " + where.toString()); - } - SqlBasicCall sqlBasicCall = (SqlBasicCall) where; - return parseColumnNameList(sqlBasicCall); + return parseColumnNameList(where); } private static List parseColumnNameList(SqlNode sqlNode) { @@ -310,6 +305,9 @@ private static List parseColumnNameList(SqlNode sqlNode) { } else if (sqlNode instanceof SqlBasicCall) { SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode; findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList); + } else if (sqlNode instanceof SqlCase) { + SqlCase sqlCase = (SqlCase) sqlNode; + findSqlIdentifier(sqlCase.getWhenOperands().getList(), columnNameList); } return columnNameList; } @@ -323,6 +321,10 @@ private static void findSqlIdentifier(List sqlNodes, List colum } else if (sqlNode instanceof SqlBasicCall) { SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode; findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList); + } else if (sqlNode instanceof SqlCase) { + SqlCase sqlCase = (SqlCase) sqlNode; + SqlNodeList whenOperands = sqlCase.getWhenOperands(); + findSqlIdentifier(whenOperands.getList(), columnNameList); } } } diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/TransformDataOperatorTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/TransformDataOperatorTest.java index 5f6d5e921e4..f8fd206bd09 100644 --- a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/TransformDataOperatorTest.java +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/TransformDataOperatorTest.java @@ -578,9 +578,23 @@ void testBuildInFunctionTransform() throws Exception { testExpressionConditionTransform("ceil(2.4) = 3.0"); testExpressionConditionTransform("floor(2.5) = 2.0"); testExpressionConditionTransform("round(3.1415926,2) = 3.14"); + testExpressionConditionTransform("IF(2>0,1,0) = 1"); + testExpressionConditionTransform("COALESCE(null,1,2) = 1"); + testExpressionConditionTransform("1 + 1 = 2"); + testExpressionConditionTransform("1 - 1 = 0"); + testExpressionConditionTransform("1 * 1 = 1"); + testExpressionConditionTransform("3 % 2 = 1"); + testExpressionConditionTransform("1 < 2"); + testExpressionConditionTransform("1 <= 1"); + testExpressionConditionTransform("1 > 0"); + testExpressionConditionTransform("1 >= 1"); + testExpressionConditionTransform( + "case 1 when 1 then 'a' when 2 then 'b' else 'c' end = 'a'"); + testExpressionConditionTransform("case col1 when '1' then true else false end"); + testExpressionConditionTransform("case when col1 = '1' then true else false end"); } - void testExpressionConditionTransform(String expression) throws Exception { + private void testExpressionConditionTransform(String expression) throws Exception { TransformDataOperator transform = TransformDataOperator.newBuilder() .addTransform( diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/TransformParserTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/TransformParserTest.java index 312493cfbd7..9dffeb84a17 100644 --- a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/TransformParserTest.java +++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/parser/TransformParserTest.java @@ -21,8 +21,6 @@ import org.apache.flink.cdc.common.types.DataTypes; import org.apache.flink.cdc.runtime.parser.metadata.TransformSchemaFactory; import org.apache.flink.cdc.runtime.parser.metadata.TransformSqlOperatorTable; -import org.apache.flink.table.api.ApiExpression; -import org.apache.flink.table.api.Expressions; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.jdbc.CalciteSchema; @@ -260,13 +258,12 @@ public void testTranslateFilterToJaninoExpression() { testFilterExpression("upper(lower(id))", "upper(lower(id))"); testFilterExpression( "abs(uniq_id) > 10 and id is not null", "abs(uniq_id) > 10 && null != id"); - } - - @Test - public void testSqlCall() { - ApiExpression apiExpression = Expressions.concat("1", "2"); - ApiExpression substring = apiExpression.substring(1); - System.out.println(substring); + testFilterExpression( + "case id when 1 then 'a' when 2 then 'b' else 'c' end", + "(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")"); + testFilterExpression( + "case when id = 1 then 'a' when id = 2 then 'b' else 'c' end", + "(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")"); } private void testFilterExpression(String expression, String expressionExpect) {