diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index c2cb1d9102..62bd70f06e 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -47,7 +47,6 @@ use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::datetime::last_day::SparkLastDay;
 use datafusion_spark::function::datetime::next_day::SparkNextDay;
 use datafusion_spark::function::hash::sha1::SparkSha1;
-use datafusion_spark::function::hash::sha2::SparkSha2;
 use datafusion_spark::function::map::map_from_entries::MapFromEntries;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::math::hex::SparkHex;
@@ -362,7 +361,6 @@ fn prepare_datafusion_session_context(
 // register UDFs from datafusion-spark crate
 fn register_datafusion_spark_function(session_ctx: &SessionContext) {
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
-    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala
index b059199735..cf45135637 100644
--- a/spark/src/main/scala/org/apache/comet/serde/hash.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala
@@ -68,6 +68,16 @@ object CometMurmur3Hash extends CometExpressionSerde[Murmur3Hash] {
 }
 
 object CometSha2 extends CometExpressionSerde[Sha2] {
+  override def getSupportLevel(expr: Sha2): SupportLevel = {
+    // If all children are foldable (constant/scalar), let Spark evaluate sha2
+    // to avoid relying on DataFusion support for purely scalar invocations.
+    if (expr.children.forall(_.foldable)) {
+      Unsupported(Some("sha2 with all scalar arguments is evaluated in Spark"))
+    } else {
+      Compatible()
+    }
+  }
+
   override def convert(
       expr: Sha2,
       inputs: Seq[Attribute],
diff --git a/spark/src/test/resources/sql-tests/expressions/hash/hash.sql b/spark/src/test/resources/sql-tests/expressions/hash/hash.sql
index 35031ea7e4..1f715bac5c 100644
--- a/spark/src/test/resources/sql-tests/expressions/hash/hash.sql
+++ b/spark/src/test/resources/sql-tests/expressions/hash/hash.sql
@@ -28,5 +28,5 @@ query
 SELECT md5(col), md5(cast(a as string)), md5(cast(b as string)), hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) FROM test
 
 -- literal arguments
-query ignore(https://github.com/apache/datafusion-comet/issues/3340)
+query
 SELECT md5('Spark SQL'), sha1('test'), sha2('test', 256), hash('test'), xxhash64('test')
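
Note (not part of the patch): a minimal sketch of the Catalyst Expression.foldable distinction that the new CometSha2.getSupportLevel keys on. Literal-only calls such as sha2('test', 256) are fully foldable and now fall back to Spark, while column-backed calls remain eligible for native execution in DataFusion. The object name FoldableCheckDemo is illustrative only; it assumes spark-catalyst is on the classpath.

// Illustrative sketch: how Expression.foldable separates the two
// sha2 shapes that getSupportLevel now routes differently.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, Sha2}
import org.apache.spark.sql.types.StringType

object FoldableCheckDemo extends App {
  // All-literal invocation: every child is foldable, so the serde
  // reports Unsupported and Spark evaluates sha2 itself.
  val scalarOnly = Sha2(Literal("test"), Literal(256))
  println(scalarOnly.children.forall(_.foldable)) // true -> evaluated by Spark

  // Column-backed invocation: the attribute is not foldable, so the
  // expression stays Compatible and can run on the native path.
  val columnBacked = Sha2(AttributeReference("col", StringType)(), Literal(256))
  println(columnBacked.children.forall(_.foldable)) // false -> native execution
}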