diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 53c511a87f69..9ef0d481bc9c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -364,6 +364,42 @@
*    <li>Since version: 3.4.0</li>
*   </ul>
*  </li>
+ *  <li>Name: <code>AES_ENCRYPT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>AES_ENCRYPT(expr, key[, mode[, padding]])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>AES_DECRYPT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>AES_DECRYPT(expr, key[, mode[, padding]])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SHA1</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SHA1(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SHA2</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SHA2(expr, bitLength)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>MD5</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>MD5(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CRC32</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CRC32(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
* </ol>
* Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
* including: add, subtract, multiply, divide, remainder, pmod.
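The Javadoc above only fixes the canonical names and argument order; on the receiving end, a connector sees these names inside a `GeneralScalarExpression`. A minimal sketch of how a data source might map the single-argument hash functions to its native SQL (the `translateToNativeSql` helper and its dispatch are hypothetical, not part of this patch):

```scala
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, NamedReference}

// Hypothetical connector-side helper: map the new one-argument hash names to
// native SQL, returning None so Spark evaluates anything the source cannot.
def translateToNativeSql(expr: Expression): Option[String] = expr match {
  case ref: NamedReference =>
    Some(ref.fieldNames.mkString("."))   // plain column reference
  case e: GeneralScalarExpression =>
    e.name match {
      case "MD5" | "SHA1" | "CRC32" =>   // one-argument functions
        translateToNativeSql(e.children()(0)).map(arg => s"${e.name}($arg)")
      case _ => None
    }
  case _ => None
}
```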
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 541b88a5027d..3a78a946e36a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -149,6 +149,12 @@ public String build(Expression expr) {
case "DATE_ADD":
case "DATE_DIFF":
case "TRUNC":
+ case "AES_ENCRYPT":
+ case "AES_DECRYPT":
+ case "SHA1":
+ case "SHA2":
+ case "MD5":
+ case "CRC32":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "CASE_WHEN": {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 414155537290..a029c002d0d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -257,6 +257,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
generateExpression(child).map(v => new V2Extract("WEEK", v))
case YearOfWeek(child) =>
generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
+ case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children)
+ case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children)
+ case Crc32(child) => generateExpressionWithName("CRC32", Seq(child))
+ case Md5(child) => generateExpressionWithName("MD5", Seq(child))
+ case Sha1(child) => generateExpressionWithName("SHA1", Seq(child))
+ case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children)
// TODO supports other expressions
case ApplyFunctionExpression(function, children) =>
val childrenExpressions = children.flatMap(generateExpression(_))
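`generateExpressionWithName` is referenced but not shown in this hunk. A plausible sketch of what it does, assuming it succeeds only when every catalyst child translates to a connector expression (`V2Expression` standing for the connector-side `Expression` alias this file uses):

```scala
private def generateExpressionWithName(
    name: String,
    children: Seq[Expression]): Option[V2Expression] = {
  // Translate every catalyst child; flatMap drops the ones that fail.
  val childrenExpressions = children.flatMap(generateExpression(_))
  // Push down only if all children translated, otherwise give up.
  if (childrenExpressions.length == children.length) {
    Some(new GeneralScalarExpression(name, childrenExpressions.toArray))
  } else {
    None
  }
}
```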
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 4200ba91fb1b..737e3de10a92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -50,7 +50,7 @@ private[sql] object H2Dialect extends JdbcDialect {
Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP",
"POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN",
"TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN",
- "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM")
+ "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM", "MD5", "SHA1", "SHA2")
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
@@ -235,5 +235,22 @@ private[sql] object H2Dialect extends JdbcDialect {
}
s"EXTRACT($newField FROM $source)"
}
+
+ override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+ if (isSupportedFunction(funcName)) {
+ funcName match {
+ case "MD5" =>
+ "RAWTOHEX(HASH('MD5', " + inputs.mkString(",") + "))"
+ case "SHA1" =>
+ "RAWTOHEX(HASH('SHA-1', " + inputs.mkString(",") + "))"
+ case "SHA2" =>
+ "RAWTOHEX(HASH('SHA-" + inputs(1) + "'," + inputs(0) + "))"
+ case _ => super.visitSQLFunction(funcName, inputs)
+ }
+ } else {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} does not support function: $funcName");
+ }
+ }
}
}
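For reference, a standalone sketch of the rewrite the override performs: H2's `HASH` function takes the algorithm name as its first argument, and `RAWTOHEX` turns the binary digest into the hex string that Spark's `md5`/`sha1`/`sha2` produce, so the comparison can be evaluated entirely inside H2:

```scala
// Mirrors the override above; inputs are the already-compiled child SQL.
def h2HashSql(funcName: String, inputs: Array[String]): String = funcName match {
  case "MD5"  => s"RAWTOHEX(HASH('MD5', ${inputs(0)}))"
  case "SHA1" => s"RAWTOHEX(HASH('SHA-1', ${inputs(0)}))"
  case "SHA2" => s"RAWTOHEX(HASH('SHA-${inputs(1)}', ${inputs(0)}))"
}

h2HashSql("SHA2", Array("B", "256"))  // RAWTOHEX(HASH('SHA-256', B))
```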
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 3b226d606430..d5255fa1c59f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -45,6 +45,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val tempDir = Utils.createTempDir()
val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
+ val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++
+ Array.fill(15)(0.toByte)
val testH2Dialect = new JdbcDialect {
override def canHandle(url: String): Boolean = H2Dialect.canHandle(url)
@@ -178,6 +180,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
"('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " +
"('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE \"test\".\"binary1\" (name TEXT(32),b BINARY(20))")
+ .executeUpdate()
+ val stmt = conn.prepareStatement("INSERT INTO \"test\".\"binary1\" VALUES (?, ?)")
+ stmt.setString(1, "jen")
+ stmt.setBytes(2, testBytes)
+ stmt.executeUpdate()
}
H2Dialect.registerFunction("my_avg", IntegralAverage)
H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
@@ -860,7 +869,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df2)
checkPushedInfo(df2,
"PushedFilters: [DEPT IS NOT NULL, DEPT > 1]",
- "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1")
+ "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1")
checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
}
@@ -1190,6 +1199,52 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df8, Seq(Row("alex")))
}
+ test("scan with filter push-down with misc functions") {
+ val df1 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+ "md5(b) = '4371fe0aa613bcb081543a37d241adcb'")
+ checkFiltersRemoved(df1)
+ val expectedPlanFragment1 = "PushedFilters: [B IS NOT NULL, " +
+ "MD5(B) = '4371fe0aa613bcb081543a37d241adcb']"
+ checkPushedInfo(df1, expectedPlanFragment1)
+ checkAnswer(df1, Seq(Row("jen")))
+
+ val df2 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+ "sha1(b) = 'cf355e86e8666f9300ef12e996acd5c629e0b0a1'")
+ checkFiltersRemoved(df2)
+ val expectedPlanFragment2 = "PushedFilters: [B IS NOT NULL, " +
+ "SHA1(B) = 'cf355e86e8666f9300ef12e996acd5c629e0b0a1'],"
+ checkPushedInfo(df2, expectedPlanFragment2)
+ checkAnswer(df2, Seq(Row("jen")))
+
+ val df3 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+ "sha2(b, 256) = '911732d10153f859dec04627df38b19290ec707ff9f83910d061421fdc476109'")
+ checkFiltersRemoved(df3)
+ val expectedPlanFragment3 = "PushedFilters: [B IS NOT NULL, (SHA2(B, 256)) = " +
+ "'911732d10153f859dec04627df38b19290ec707ff9f83910d061421fdc476109']"
+ checkPushedInfo(df3, expectedPlanFragment3)
+ checkAnswer(df3, Seq(Row("jen")))
+
+ val df4 = sql("SELECT * FROM h2.test.employee WHERE crc32(name) = '142689369'")
+ checkFiltersRemoved(df4, false)
+ val expectedPlanFragment4 = "PushedFilters: [NAME IS NOT NULL], "
+ checkPushedInfo(df4, expectedPlanFragment4)
+ checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ val df5 = sql("SELECT name FROM h2.test.employee WHERE " +
+ "aes_encrypt(cast(null as string), name) is null")
+ checkFiltersRemoved(df5, false)
+ val expectedPlanFragment5 = "PushedFilters: [], "
+ checkPushedInfo(df5, expectedPlanFragment5)
+ checkAnswer(df5, Seq(Row("amy"), Row("cathy"), Row("alex"), Row("david"), Row("jen")))
+
+ val df6 = sql("SELECT name FROM h2.test.employee WHERE " +
+ "aes_decrypt(cast(null as binary), name) is null")
+ checkFiltersRemoved(df6, false)
+ val expectedPlanFragment6 = "PushedFilters: [], "
+ checkPushedInfo(df6, expectedPlanFragment6)
+ checkAnswer(df6, Seq(Row("amy"), Row("cathy"), Row("alex"), Row("david"), Row("jen")))
+ }
+
test("scan with filter push-down with UDF") {
JdbcDialects.unregisterDialect(H2Dialect)
try {
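A note on df4 through df6 in the misc-function test above: CRC32 and the AES functions do translate to V2 expressions, but the `supportedFunctions` allow-list extended at the top of this patch only admits MD5, SHA1, and SHA2, so those filters stay in Spark. A minimal sketch of the gate:

```scala
// The allow-list gate from H2Dialect, reduced to the new entries.
val supportedFunctions = Set("MD5", "SHA1", "SHA2")
def isSupportedFunction(funcName: String): Boolean =
  supportedFunctions.contains(funcName)

isSupportedFunction("SHA2")         // true  -> SHA2(B, 256) is compiled and pushed
isSupportedFunction("CRC32")        // false -> df4 keeps the CRC32 filter in Spark
isSupportedFunction("AES_ENCRYPT")  // false -> df5/df6 push no filters at all
```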
@@ -1269,7 +1324,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
Seq(Row("test", "people", false), Row("test", "empty_table", false),
Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false),
- Row("test", "datetime", false)))
+ Row("test", "datetime", false), Row("test", "binary1", false)))
}
test("SQL API: create table as select") {
@@ -1819,12 +1874,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkFiltersRemoved(df)
checkAggregateRemoved(df)
checkPushedInfo(df,
- """
- |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS),
- |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)],
- |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
- |PushedGroupByExpressions: [DEPT],
- |""".stripMargin.replaceAll("\n", " "))
+ """
+ |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS),
+ |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)],
+ |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+ |PushedGroupByExpressions: [DEPT],
+ |""".stripMargin.replaceAll("\n", " "))
checkAnswer(df, Seq(Row(10000d, 10000d, 20000d, 20000d),
Row(2500d, 2500d, 5000d, 5000d), Row(0d, 0d, null, null)))
}