Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput

def propagateNull: Boolean

override def foldable: Boolean = children.forall(_.foldable) && deterministic
override def foldable: Boolean =
children.forall(_.foldable) && deterministic && trustedSerializable(dataType)
protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true)
protected lazy val needNullCheckForIndex: Array[Boolean] =
arguments.map(a => a.nullable && (propagateNull ||
Expand All @@ -62,6 +63,14 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
.map(cls => v => cls.cast(v))
.getOrElse(identity)

// Returns true if we can trust all values of the given DataType can be serialized.
private def trustedSerializable(dt: DataType): Boolean = {
// Right now we conservatively block all ObjectType (Java objects) regardless of
// serializability, because the type-level info with java.io.Serializable and
// java.io.Externalizable marker interfaces are not strong guarantees.
// This restriction can be relaxed in the future to expose more optimizations.
!dt.existsRecursively(_.isInstanceOf[ObjectType])
}

/**
* Prepares codes for arguments.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,51 @@ class ConstantFoldingSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("SPARK-40380: InvokeLike should only constant-fold to serializable types") {
val serializableObjType = ObjectType(classOf[SerializableBoxedInt])
val notSerializableObjType = ObjectType(classOf[NotSerializableBoxedInt])

val originalQuery =
testRelation
.select(
// SerializableBoxedInt(42).add(1).toNotSerializable().addAsInt($"a")
Invoke(
Invoke(
Invoke(
Literal.fromObject(SerializableBoxedInt(42), serializableObjType),
"add",
serializableObjType,
Literal(1) :: Nil
),
"toNotSerializable",
notSerializableObjType),
"addAsInt",
IntegerType,
$"a" :: Nil).as("c1"))

val optimized = Optimize.execute(originalQuery.analyze)

val correctAnswer = originalQuery.analyze

// If serializable ObjectType is allowed to be constant-folded in the future, this chain can
// be optimized into:
// val correctAnswer =
// testRelation
// .select(
// // SerializableBoxedInt(43).toNotSerializable().addAsInt($"a")
// Invoke(
// Invoke(
// Literal.fromObject(SerializableBoxedInt(43), serializableObjType),
// "toNotSerializable",
// notSerializableObjType),
// "addAsInt",
// IntegerType,
// $"a" :: Nil).as("c1"))
// .analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-39106: Correct conditional expression constant folding") {
val t = LocalRelation.fromExternalRows(
$"c".double :: Nil,
Expand Down Expand Up @@ -371,3 +416,12 @@ class ConstantFoldingSuite extends PlanTest {
}
}
}

case class SerializableBoxedInt(intVal: Int) {
def add(other: Int): SerializableBoxedInt = SerializableBoxedInt(intVal + other)
def toNotSerializable(): NotSerializableBoxedInt = new NotSerializableBoxedInt(intVal)
}

class NotSerializableBoxedInt(intVal: Int) {
def addAsInt(other: Int): Int = intVal + other
}