diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCastFunctionTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCastFunctionTest.java index 3a0da79657f..95fc53b588f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCastFunctionTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCastFunctionTest.java @@ -28,7 +28,7 @@ public void testCast() { verifyLogical(root, expectedLogical); // TODO there is no SAFE_CAST() in Spark, the Spark CAST is always safe (return null). - String expectedSparkSql = "SELECT SAFE_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`"; + String expectedSparkSql = "SELECT TRY_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -40,7 +40,7 @@ public void testCastInsensitive() { "" + "LogicalProject(a=[SAFE_CAST($3)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); - String expectedSparkSql = "SELECT SAFE_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`"; + String expectedSparkSql = "SELECT TRY_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -56,7 +56,7 @@ public void testCastOverriding() { String expectedSparkSql = "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " SAFE_CAST(`MGR` AS STRING) `age`\n" + + " TRY_CAST(`MGR` AS STRING) `age`\n" + "FROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -83,7 +83,7 @@ public void testChainedCast() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "" + "SELECT SAFE_CAST(SAFE_CAST(`MGR` AS STRING) AS INTEGER) `a`\n" + "FROM `scott`.`EMP`"; + "" + "SELECT TRY_CAST(TRY_CAST(`MGR` AS STRING) AS INTEGER) `a`\n" + "FROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -117,7 +117,7 @@ public void testChainedCast2() { String expectedSparkSql = "" - + "SELECT SAFE_CAST(CONCAT(SAFE_CAST(`MGR` AS STRING), '0') AS INTEGER) `a`\n" + + "SELECT TRY_CAST(CONCAT(TRY_CAST(`MGR` AS STRING), '0') AS INTEGER) `a`\n" + "FROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLChartTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLChartTest.java index 338b586ba29..bddcde11e18 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLChartTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLChartTest.java @@ -111,14 +111,14 @@ public void testChartWithMultipleGroupKeys() { "SELECT `t2`.`gender`, CASE WHEN `t2`.`age` IS NULL THEN 'NULL' WHEN" + " `t9`.`_row_number_chart_` <= 10 THEN `t2`.`age` ELSE 'OTHER' END `age`," + " AVG(`t2`.`avg(balance)`) `avg(balance)`\n" - + "FROM (SELECT `gender`, SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`)" + + "FROM (SELECT `gender`, TRY_CAST(`age` AS STRING) `age`, AVG(`balance`)" + " `avg(balance)`\n" + "FROM `scott`.`bank`\n" + "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n" + "GROUP BY `gender`, `age`) `t2`\n" + "LEFT JOIN (SELECT `age`, SUM(`avg(balance)`) `__grand_total__`, ROW_NUMBER() OVER" + " (ORDER BY SUM(`avg(balance)`) DESC) `_row_number_chart_`\n" - + "FROM (SELECT SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n" + + "FROM (SELECT TRY_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n" + "FROM `scott`.`bank`\n" + "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n" + "GROUP BY `gender`, `age`) `t6`\n" @@ -139,14 +139,14 @@ public void testChartWithMultipleGroupKeysAlternativeSyntax() { "SELECT `t2`.`gender`, CASE WHEN `t2`.`age` IS NULL THEN 'NULL' WHEN" + " `t9`.`_row_number_chart_` <= 10 THEN `t2`.`age` ELSE 'OTHER' END `age`," + " AVG(`t2`.`avg(balance)`) `avg(balance)`\n" - + "FROM (SELECT `gender`, SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`)" + + "FROM (SELECT `gender`, TRY_CAST(`age` AS STRING) `age`, AVG(`balance`)" + " `avg(balance)`\n" + "FROM `scott`.`bank`\n" + "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n" + "GROUP BY `gender`, `age`) `t2`\n" + "LEFT JOIN (SELECT `age`, SUM(`avg(balance)`) `__grand_total__`, ROW_NUMBER() OVER" + " (ORDER BY SUM(`avg(balance)`) DESC) `_row_number_chart_`\n" - + "FROM (SELECT SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n" + + "FROM (SELECT TRY_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n" + "FROM `scott`.`bank`\n" + "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n" + "GROUP BY `gender`, `age`) `t6`\n" diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsEarliestLatestTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsEarliestLatestTest.java index d91a8638cb2..f76a7af2c79 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsEarliestLatestTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsEarliestLatestTest.java @@ -48,7 +48,7 @@ public void testEventstatsEarliestWithoutSecondArgument() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`," + " `@timestamp`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)" + " `earliest_message`\n" + "FROM `POST`.`LOGS`"; @@ -66,7 +66,7 @@ public void testEventstatsLatestWithoutSecondArgument() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY(`message`," + " `@timestamp`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)" + " `latest_message`\n" + "FROM `POST`.`LOGS`"; @@ -84,7 +84,7 @@ public void testEventstatsEarliestByServerWithoutSecondArgument() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`," + " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND" + " UNBOUNDED FOLLOWING) `earliest_message`\n" + "FROM `POST`.`LOGS`"; @@ -102,7 +102,7 @@ public void testEventstatsLatestByServerWithoutSecondArgument() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY(`message`," + " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND" + " UNBOUNDED FOLLOWING) `latest_message`\n" + "FROM `POST`.`LOGS`"; @@ -122,7 +122,7 @@ public void testEventstatsEarliestWithOtherAggregatesWithoutSecondArgument() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`," + " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND" + " UNBOUNDED FOLLOWING) `earliest_message`, COUNT(*) OVER (PARTITION BY `server` RANGE" + " BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `cnt`\n" @@ -141,7 +141,7 @@ public void testEventstatsEarliestWithExplicitTimestampField() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`," + " `created_at`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)" + " `earliest_message`\n" + "FROM `POST`.`LOGS`"; @@ -159,7 +159,7 @@ public void testEventstatsLatestWithExplicitTimestampField() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY(`message`," + " `created_at`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)" + " `latest_message`\n" + "FROM `POST`.`LOGS`"; @@ -180,9 +180,9 @@ public void testEventstatsEarliestLatestCombined() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`," + "SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`," + " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND" - + " UNBOUNDED FOLLOWING) `earliest_msg`, MAX_BY (`message`, `@timestamp`) OVER" + + " UNBOUNDED FOLLOWING) `earliest_msg`, MAX_BY(`message`, `@timestamp`) OVER" + " (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)" + " `latest_msg`\n" + "FROM `POST`.`LOGS`"; diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java index d72c3b086cc..f1dfd930a82 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java @@ -64,9 +64,9 @@ public void testPatternsLabelMode_ShowNumberedToken_ForSimplePatternMethod() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + "SELECT `ENAME`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + " '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END, `ENAME`)['pattern'] AS" - + " STRING) `patterns_field`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR" + + " STRING) `patterns_field`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR" + " `ENAME` = '' THEN '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END," + " `ENAME`)['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >) `tokens`\n" + "FROM `scott`.`EMP`"; @@ -91,9 +91,9 @@ public void testPatternsLabelModeWithCustomPattern_ShowNumberedToken_ForSimplePa verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + "SELECT `ENAME`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + " '' ELSE REGEXP_REPLACE(`ENAME`, '[A-H]', '<*>') END, `ENAME`)['pattern'] AS STRING)" - + " `patterns_field`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` =" + + " `patterns_field`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` =" + " '' THEN '' ELSE REGEXP_REPLACE(`ENAME`, '[A-H]', '<*>') END, `ENAME`)['tokens'] AS" + " MAP< VARCHAR, VARCHAR ARRAY >) `tokens`\n" + "FROM `scott`.`EMP`"; @@ -138,9 +138,9 @@ public void testPatternsLabelModeWithPartitionBy_ShowNumberedToken_SimplePattern verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, `DEPTNO`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME`" + "SELECT `ENAME`, `DEPTNO`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME`" + " = '' THEN '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END," - + " `ENAME`)['pattern'] AS STRING) `patterns_field`, SAFE_CAST(`PATTERN_PARSER`(CASE" + + " `ENAME`)['pattern'] AS STRING) `patterns_field`, TRY_CAST(`PATTERN_PARSER`(CASE" + " WHEN `ENAME` IS NULL OR `ENAME` = '' THEN '' ELSE REGEXP_REPLACE(`ENAME`," + " '[a-zA-Z0-9]+', '<*>') END, `ENAME`)['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >)" + " `tokens`\n" @@ -160,7 +160,7 @@ public void testPatternsLabelMode_NotShowNumberedToken_ForBrainMethod() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, SAFE_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10, 100000, FALSE)" + "SELECT `ENAME`, TRY_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10, 100000, FALSE)" + " OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), FALSE)['pattern']" + " AS STRING) `patterns_field`\n" + "FROM `scott`.`EMP`"; @@ -183,9 +183,9 @@ public void testPatternsLabelMode_ShowNumberedToken_ForBrainMethod() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, SAFE_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10, 100000, TRUE)" + "SELECT `ENAME`, TRY_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10, 100000, TRUE)" + " OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), TRUE)['pattern']" - + " AS STRING) `patterns_field`, SAFE_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`," + + " AS STRING) `patterns_field`, TRY_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`," + " 10, 100000, TRUE) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)," + " TRUE)['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >) `tokens`\n" + "FROM `scott`.`EMP`"; @@ -206,7 +206,7 @@ public void testPatternsLabelModeWithPartitionBy_NotShowNumberedToken_ForBrainMe verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, `DEPTNO`, SAFE_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10," + "SELECT `ENAME`, `DEPTNO`, TRY_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10," + " 100000, FALSE) OVER (PARTITION BY `DEPTNO` RANGE BETWEEN UNBOUNDED PRECEDING AND" + " UNBOUNDED FOLLOWING), FALSE)['pattern'] AS STRING) `patterns_field`\n" + "FROM `scott`.`EMP`"; @@ -229,10 +229,10 @@ public void testPatternsLabelModeWithPartitionBy_ShowNumberedToken_ForBrainMetho verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `ENAME`, `DEPTNO`, SAFE_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10," + "SELECT `ENAME`, `DEPTNO`, TRY_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10," + " 100000, TRUE) OVER (PARTITION BY `DEPTNO` RANGE BETWEEN UNBOUNDED PRECEDING AND" + " UNBOUNDED FOLLOWING), TRUE)['pattern'] AS STRING) `patterns_field`," - + " SAFE_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10, 100000, TRUE) OVER" + + " TRY_CAST(`PATTERN_PARSER`(`ENAME`, `pattern`(`ENAME`, 10, 100000, TRUE) OVER" + " (PARTITION BY `DEPTNO` RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)," + " TRUE)['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >) `tokens`\n" + "FROM `scott`.`EMP`"; @@ -281,11 +281,11 @@ public void testPatternsAggregationMode_ShowNumberedToken_ForSimplePatternMethod verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN '' ELSE" + "SELECT TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN '' ELSE" + " REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END, `TAKE`(`ENAME`, 10))['pattern']" + " AS STRING) `patterns_field`, COUNT(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + " '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END) `pattern_count`," - + " SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN '' ELSE" + + " TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN '' ELSE" + " REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END, `TAKE`(`ENAME`, 10))['tokens']" + " AS MAP< VARCHAR, VARCHAR ARRAY >) `tokens`, `TAKE`(`ENAME`, 10) `sample_logs`\n" + "FROM `scott`.`EMP`\n" @@ -312,11 +312,11 @@ public void testPatternsAggregationModeWithGroupBy_ShowNumberedToken_ForSimplePa verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `DEPTNO`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + "SELECT `DEPTNO`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` = '' THEN" + " '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END, `TAKE`(`ENAME`," + " 10))['pattern'] AS STRING) `patterns_field`, COUNT(CASE WHEN `ENAME` IS NULL OR" + " `ENAME` = '' THEN '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END)" - + " `pattern_count`, SAFE_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` =" + + " `pattern_count`, TRY_CAST(`PATTERN_PARSER`(CASE WHEN `ENAME` IS NULL OR `ENAME` =" + " '' THEN '' ELSE REGEXP_REPLACE(`ENAME`, '[a-zA-Z0-9]+', '<*>') END, `TAKE`(`ENAME`," + " 10))['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >) `tokens`, `TAKE`(`ENAME`, 10)" + " `sample_logs`\n" @@ -344,14 +344,10 @@ public void testPatternsAggregationMode_NotShowNumberedToken_ForBrainMethod() { + " LogicalValues(tuples=[[{ 0 }]])\n"; verifyLogical(root, expectedLogical); - /* - * TODO: Fix Spark SQL conformance - * Spark doesn't have SAFE_CAST and UNNEST - */ String expectedSparkSql = - "SELECT SAFE_CAST(`t20`.`patterns_field`['pattern'] AS STRING) `patterns_field`," - + " SAFE_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT) `pattern_count`," - + " SAFE_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING >) `sample_logs`\n" + "SELECT TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING) `patterns_field`," + + " TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT) `pattern_count`," + + " TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING >) `sample_logs`\n" + "FROM (SELECT `pattern`(`ENAME`, 10, 100000, FALSE) `patterns_field`\n" + "FROM `scott`.`EMP`) `$cor0`,\n" + "LATERAL UNNEST((SELECT `$cor0`.`patterns_field`\n" @@ -378,15 +374,11 @@ public void testPatternsAggregationMode_ShowNumberedToken_ForBrainMethod() { + " LogicalValues(tuples=[[{ 0 }]])\n"; verifyLogical(root, expectedLogical); - /* - * TODO: Fix Spark SQL conformance - * Spark doesn't have SAFE_CAST and UNNEST - */ String expectedSparkSql = - "SELECT SAFE_CAST(`t20`.`patterns_field`['pattern'] AS STRING) `patterns_field`," - + " SAFE_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT) `pattern_count`," - + " SAFE_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >)" - + " `tokens`, SAFE_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING >)" + "SELECT TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING) `patterns_field`," + + " TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT) `pattern_count`," + + " TRY_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >)" + + " `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING >)" + " `sample_logs`\n" + "FROM (SELECT `pattern`(`ENAME`, 10, 100000, TRUE) `patterns_field`\n" + "FROM `scott`.`EMP`) `$cor0`,\n" @@ -414,14 +406,10 @@ public void testPatternsAggregationModeWithGroupBy_NotShowNumberedToken_ForBrain + " LogicalValues(tuples=[[{ 0 }]])\n"; verifyLogical(root, expectedLogical); - /* - * TODO: Fix Spark SQL conformance - * Spark doesn't have SAFE_CAST and UNNEST - */ String expectedSparkSql = - "SELECT `$cor0`.`DEPTNO`, SAFE_CAST(`t20`.`patterns_field`['pattern'] AS STRING)" - + " `patterns_field`, SAFE_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)" - + " `pattern_count`, SAFE_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING" + "SELECT `$cor0`.`DEPTNO`, TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING)" + + " `patterns_field`, TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)" + + " `pattern_count`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING" + " >) `sample_logs`\n" + "FROM (SELECT `DEPTNO`, `pattern`(`ENAME`, 10, 100000, FALSE) `patterns_field`\n" + "FROM `scott`.`EMP`\n" @@ -451,15 +439,11 @@ public void testPatternsAggregationModeWithGroupBy_ShowNumberedToken_ForBrainMet + " LogicalValues(tuples=[[{ 0 }]])\n"; verifyLogical(root, expectedLogical); - /* - * TODO: Fix Spark SQL conformance - * Spark doesn't have SAFE_CAST and UNNEST - */ String expectedSparkSql = - "SELECT `$cor0`.`DEPTNO`, SAFE_CAST(`t20`.`patterns_field`['pattern'] AS STRING)" - + " `patterns_field`, SAFE_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)" - + " `pattern_count`, SAFE_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR," - + " VARCHAR ARRAY >) `tokens`, SAFE_CAST(`t20`.`patterns_field`['sample_logs'] AS" + "SELECT `$cor0`.`DEPTNO`, TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING)" + + " `patterns_field`, TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)" + + " `pattern_count`, TRY_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR," + + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS" + " ARRAY< STRING >) `sample_logs`\n" + "FROM (SELECT `DEPTNO`, `pattern`(`ENAME`, 10, 100000, TRUE) `patterns_field`\n" + "FROM `scott`.`EMP`\n" diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStatsEarliestLatestTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStatsEarliestLatestTest.java index f5ee3780411..cba3942da3a 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStatsEarliestLatestTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStatsEarliestLatestTest.java @@ -51,7 +51,7 @@ public void testEarliestWithoutSecondArgument() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MIN_BY (`message`, `@timestamp`) `earliest_message`\n" + "FROM `POST`.`LOGS`"; + "SELECT MIN_BY(`message`, `@timestamp`) `earliest_message`\n" + "FROM `POST`.`LOGS`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -69,7 +69,7 @@ public void testLatestWithoutSecondArgument() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MAX_BY (`message`, `@timestamp`) `latest_message`\n" + "FROM `POST`.`LOGS`"; + "SELECT MAX_BY(`message`, `@timestamp`) `latest_message`\n" + "FROM `POST`.`LOGS`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -91,7 +91,7 @@ public void testEarliestByServerWithoutSecondArgument() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MIN_BY (`message`, `@timestamp`) `earliest_message`, `server`\n" + "SELECT MIN_BY(`message`, `@timestamp`) `earliest_message`, `server`\n" + "FROM `POST`.`LOGS`\n" + "GROUP BY `server`"; verifyPPLToSparkSQL(root, expectedSparkSql); @@ -115,7 +115,7 @@ public void testLatestByServerWithoutSecondArgument() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MAX_BY (`message`, `@timestamp`) `latest_message`, `server`\n" + "SELECT MAX_BY(`message`, `@timestamp`) `latest_message`, `server`\n" + "FROM `POST`.`LOGS`\n" + "GROUP BY `server`"; verifyPPLToSparkSQL(root, expectedSparkSql); @@ -140,7 +140,7 @@ public void testEarliestWithOtherAggregatesWithoutSecondArgument() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MIN_BY (`message`, `@timestamp`) `earliest_message`, " + "SELECT MIN_BY(`message`, `@timestamp`) `earliest_message`, " + "COUNT(*) `cnt`, `server`\n" + "FROM `POST`.`LOGS`\n" + "GROUP BY `server`"; @@ -161,7 +161,7 @@ public void testEarliestWithExplicitTimestampField() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MIN_BY (`message`, `created_at`) `earliest_message`\n" + "FROM `POST`.`LOGS`"; + "SELECT MIN_BY(`message`, `created_at`) `earliest_message`\n" + "FROM `POST`.`LOGS`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -179,7 +179,7 @@ public void testLatestWithExplicitTimestampField() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT MAX_BY (`message`, `created_at`) `latest_message`\n" + "FROM `POST`.`LOGS`"; + "SELECT MAX_BY(`message`, `created_at`) `latest_message`\n" + "FROM `POST`.`LOGS`"; verifyPPLToSparkSQL(root, expectedSparkSql); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStringFunctionTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStringFunctionTest.java index f67f7601dfe..46e35290704 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStringFunctionTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLStringFunctionTest.java @@ -76,7 +76,7 @@ public void testToStringFormatNotSpecified() { verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT CAST(`MGR` AS STRING) `string_value`, SAFE_CAST(`MGR` AS STRING) `cast_value`\n" + "SELECT CAST(`MGR` AS STRING) `string_value`, TRY_CAST(`MGR` AS STRING) `cast_value`\n" + "FROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/OpenSearchSparkSqlDialect.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/OpenSearchSparkSqlDialect.java index 24ddedd2562..2d044da58e6 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/OpenSearchSparkSqlDialect.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/OpenSearchSparkSqlDialect.java @@ -24,7 +24,10 @@ public class OpenSearchSparkSqlDialect extends SparkSqlDialect { private static final Map CALCITE_TO_SPARK_MAPPING = ImmutableMap.of( "ARG_MIN", "MIN_BY", - "ARG_MAX", "MAX_BY"); + "ARG_MAX", "MAX_BY", + "SAFE_CAST", "TRY_CAST"); + + private static final Map CALL_SEPARATOR = ImmutableMap.of("SAFE_CAST", "AS"); private OpenSearchSparkSqlDialect() { super(DEFAULT_CONTEXT); @@ -37,21 +40,31 @@ public void unparseCall(SqlWriter writer, SqlCall call, int leftPrec, int rightP // Replace Calcite specific functions with their Spark SQL equivalents if (CALCITE_TO_SPARK_MAPPING.containsKey(operatorName)) { unparseFunction( - writer, call, CALCITE_TO_SPARK_MAPPING.get(operatorName), leftPrec, rightPrec); + writer, + call, + CALCITE_TO_SPARK_MAPPING.get(operatorName), + leftPrec, + rightPrec, + CALL_SEPARATOR.getOrDefault(operatorName, ",")); } else { super.unparseCall(writer, call, leftPrec, rightPrec); } } private void unparseFunction( - SqlWriter writer, SqlCall call, String functionName, int leftPrec, int rightPrec) { - writer.keyword(functionName); + SqlWriter writer, + SqlCall call, + String functionName, + int leftPrec, + int rightPrec, + String separator) { + writer.print(functionName); final SqlWriter.Frame frame = writer.startList("(", ")"); for (int i = 0; i < call.operandCount(); i++) { if (i > 0) { - writer.sep(","); + writer.sep(separator); } - call.operand(i).unparse(writer, leftPrec, rightPrec); + call.operand(i).unparse(writer, 0, rightPrec); } writer.endList(frame); }