diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 12e73f2e9fe1a..a644b90a96ff6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -50,7 +50,8 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
 
   def propagateNull: Boolean
 
-  override def foldable: Boolean = children.forall(_.foldable) && deterministic
+  override def foldable: Boolean =
+    children.forall(_.foldable) && deterministic && trustedSerializable(dataType)
   protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true)
   protected lazy val needNullCheckForIndex: Array[Boolean] =
     arguments.map(a => a.nullable && (propagateNull ||
@@ -62,6 +63,14 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
       .map(cls => v => cls.cast(v))
       .getOrElse(identity)
 
+  // Returns true if we can trust all values of the given DataType can be serialized.
+  private def trustedSerializable(dt: DataType): Boolean = {
+    // Right now we conservatively block all ObjectType (Java objects) regardless of
+    // serializability, because the type-level info with java.io.Serializable and
+    // java.io.Externalizable marker interfaces are not strong guarantees.
+    // This restriction can be relaxed in the future to expose more optimizations.
+    !dt.existsRecursively(_.isInstanceOf[ObjectType])
+  }
 
   /**
    * Prepares codes for arguments.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index e6605b496972b..90882da0cab3b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -336,6 +336,51 @@ class ConstantFoldingSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("SPARK-40380: InvokeLike should only constant-fold to serializable types") {
+    val serializableObjType = ObjectType(classOf[SerializableBoxedInt])
+    val notSerializableObjType = ObjectType(classOf[NotSerializableBoxedInt])
+
+    val originalQuery =
+      testRelation
+        .select(
+          // SerializableBoxedInt(42).add(1).toNotSerializable().addAsInt($"a")
+          Invoke(
+            Invoke(
+              Invoke(
+                Literal.fromObject(SerializableBoxedInt(42), serializableObjType),
+                "add",
+                serializableObjType,
+                Literal(1) :: Nil
+              ),
+              "toNotSerializable",
+              notSerializableObjType),
+            "addAsInt",
+            IntegerType,
+            $"a" :: Nil).as("c1"))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = originalQuery.analyze
+
+    // If serializable ObjectType is allowed to be constant-folded in the future, this chain can
+    // be optimized into:
+    // val correctAnswer =
+    //   testRelation
+    //     .select(
+    //       // SerializableBoxedInt(43).toNotSerializable().addAsInt($"a")
+    //       Invoke(
+    //         Invoke(
+    //           Literal.fromObject(SerializableBoxedInt(43), serializableObjType),
+    //           "toNotSerializable",
+    //           notSerializableObjType),
+    //         "addAsInt",
+    //         IntegerType,
+    //         $"a" :: Nil).as("c1"))
+    //       .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("SPARK-39106: Correct conditional expression constant folding") {
     val t = LocalRelation.fromExternalRows(
       $"c".double :: Nil,
@@ -371,3 +416,12 @@ class ConstantFoldingSuite extends PlanTest {
} } } + +case class SerializableBoxedInt(intVal: Int) { + def add(other: Int): SerializableBoxedInt = SerializableBoxedInt(intVal + other) + def toNotSerializable(): NotSerializableBoxedInt = new NotSerializableBoxedInt(intVal) +} + +class NotSerializableBoxedInt(intVal: Int) { + def addAsInt(other: Int): Int = intVal + other +}