Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ object ScalaReflection extends ScalaReflection {
*/
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])

private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
tpe match {
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
Expand Down Expand Up @@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection {
* Array[T]. Special handling is performed for primitive types to map them back to their raw
* JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
val cls = tpe match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
Expand Down Expand Up @@ -145,7 +145,7 @@ object ScalaReflection extends ScalaReflection {
private def deserializerFor(
tpe: `Type`,
path: Option[Expression],
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {

/** Returns the current path with a sub-field extracted. */
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
Expand Down Expand Up @@ -452,7 +452,7 @@ object ScalaReflection extends ScalaReflection {
inputObject: Expression,
tpe: `Type`,
walkedTypePath: Seq[String],
seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized {
seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects {

def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
dataTypeFor(elementType) match {
Expand Down Expand Up @@ -638,7 +638,7 @@ object ScalaReflection extends ScalaReflection {
* Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that,
* we also treat [[DefinedByConstructorParams]] as product type.
*/
def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects {
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Expand Down Expand Up @@ -700,7 +700,7 @@ object ScalaReflection extends ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects {
tpe match {
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Expand Down Expand Up @@ -766,7 +766,7 @@ object ScalaReflection extends ScalaReflection {
/**
* Whether the fields of the given type are defined entirely by its constructor parameters.
*/
def definedByConstructorParams(tpe: Type): Boolean = {
def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
  // A type qualifies when it conforms to either Product or DefinedByConstructorParams.
  // The second check is only evaluated when the first one fails.
  val conformsToProduct = tpe <:< localTypeOf[Product]
  conformsToProduct || tpe <:< localTypeOf[DefinedByConstructorParams]
}

Expand Down Expand Up @@ -795,6 +795,20 @@ trait ScalaReflection {
// Since the map values can be mutable, we explicitly import scala.collection.Map here.
import scala.collection.Map

/**
 * Runs `func` inside the Scala reflection universe's undo log so that reflection garbage
 * produced while evaluating it is cleaned up automatically. Any code calling
 * `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method; otherwise some
 * objects are leaked to `scala.reflect.runtime.JavaUniverse.undoLog`.
 *
 * `func` is also evaluated under `ScalaReflectionLock.synchronized`, so callers do not need
 * to take that lock themselves.
 *
 * @param func the block to evaluate
 * @return the value produced by `func`
 * @see https://github.com/scala/bug/issues/8302
 */
def cleanUpReflectionObjects[T](func: => T): T = ScalaReflectionLock.synchronized {
  val javaUniverse = universe.asInstanceOf[scala.reflect.runtime.JavaUniverse]
  javaUniverse.undoLog.undo(func)
}

/**
* Return the Scala Type for `T` in the current classloader mirror.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ClosureCleaner

case class RepeatedStruct(s: Seq[PrimitiveData])

Expand Down Expand Up @@ -114,7 +115,9 @@ object ReferenceValueClass {
class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)

implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
// Implicit encoder for any `T` with a `TypeTag`. Construction goes through
// `verifyNotLeakingReflectionObjects`, so building the encoder is itself checked.
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] =
  verifyNotLeakingReflectionObjects(ExpressionEncoder[T]())

// test flat encoders
encodeDecodeTest(false, "primitive boolean")
Expand Down Expand Up @@ -370,8 +373,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
testName: String): Unit = {
test(s"encode/decode for $testName: $input") {
testAndVerifyNotLeakingReflectionObjects(s"encode/decode for $testName: $input") {
val encoder = implicitly[ExpressionEncoder[T]]

// Make sure encoder is serializable.
ClosureCleaner.clean((s: String) => encoder.getClass.getName)

val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
val boundEncoder = encoder.resolveAndBind()
Expand Down Expand Up @@ -441,4 +448,28 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
}
}

/**
 * Evaluates `func` and asserts that the size of
 * scala.reflect.runtime.JavaUniverse.undoLog is the same before and after, to ensure we
 * don't leak Scala reflection garbage.
 *
 * @param func the block to evaluate; it is evaluated exactly once
 * @return the value produced by `func`
 * @see org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects
 */
private def verifyNotLeakingReflectionObjects[T](func: => T): T = {
  // A `def`, not a `val`: the size must be re-read on each use to compare before and after.
  def currentUndoLogSize: Int = {
    val javaUniverse =
      scala.reflect.runtime.universe.asInstanceOf[scala.reflect.runtime.JavaUniverse]
    javaUniverse.undoLog.log.size
  }

  val sizeBefore = currentUndoLogSize
  val result = func
  assert(sizeBefore == currentUndoLogSize)
  result
}

/**
 * Registers a test named `testName` whose body is evaluated inside
 * `verifyNotLeakingReflectionObjects`, so the test also fails if the reflection undo log
 * changes size while `testFun` runs.
 *
 * @param testName the name under which the test is registered
 * @param testFun the test body
 */
private def testAndVerifyNotLeakingReflectionObjects(
    testName: String)(testFun: => Any): Unit = {
  // Explicit `: Unit =` replaces the deprecated procedure syntax.
  test(testName) {
    verifyNotLeakingReflectionObjects(testFun)
  }
}
}