12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1204,9 +1204,17 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
* If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
* check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
* if not.
*
* @param f the closure to clean
* @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
* @throws <tt>SparkException</tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
* serializable
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
Contributor: can you document checkSerializable in the doc? (like what exception does it throw)
ClosureCleaner.clean(f, checkSerializable)
f
}

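For illustration, a rough sketch of how the proactive check surfaces to a user. The SparkContext `sc` and the `Unserializable` helper class below are assumptions for this example, not part of the patch: a transformation whose closure captures a non-serializable object now fails when the transformation is defined rather than when the job runs.

```scala
// Hypothetical usage sketch; `sc` is an existing SparkContext.
class Unserializable {                       // note: does not extend Serializable
  def tag(x: Int): String = s"value-$x"
}

val helper = new Unserializable

// RDD.map calls sc.clean(f) with the default checkSerializable = true, so the
// closure capturing `helper` is serialized eagerly and this line throws
// SparkException("Task not serializable") immediately.
val tagged = sc.parallelize(1 to 10).map(x => helper.tag(x))
```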
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.{Logging, SparkException}
import org.apache.spark.{Logging, SparkEnv, SparkException}

private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}

def clean(func: AnyRef) {
def clean(func: AnyRef, checkSerializable: Boolean = true) {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -153,6 +153,18 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}

if (checkSerializable) {
ensureSerializable(func)
}
}

private def ensureSerializable(func: AnyRef) {
try {
SparkEnv.get.closureSerializer.newInstance().serialize(func)
} catch {
case ex: Exception => throw new SparkException("Task not serializable", ex)
}
}

private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
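A minimal standalone sketch of the check that `ensureSerializable` performs. It uses plain `java.io` serialization rather than `SparkEnv.get.closureSerializer`, purely to keep the example self-contained, so treat it as an approximation rather than the actual implementation.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Approximation of ClosureCleaner.ensureSerializable: try to serialize the
// closure up front and rethrow any failure with a clearer message.
def roughlyEnsureSerializable(closure: AnyRef): Unit = {
  try {
    val oos = new ObjectOutputStream(new ByteArrayOutputStream())
    oos.writeObject(closure)  // throws NotSerializableException for bad closures
    oos.close()
  } catch {
    case e: Exception =>
      // The real code wraps this in SparkException("Task not serializable", e).
      throw new RuntimeException("Task not serializable", e)
  }
}
```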
14 changes: 10 additions & 4 deletions core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -22,6 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.util.NonSerializable

import java.io.NotSerializableException

// Common state shared by FailureSuite-launched tasks. We use a global object
// for this because any local variables used in the task closures will rightfully
// be copied for each task, so there's no other way for them to share state.
@@ -102,7 +104,8 @@ class FailureSuite extends FunSuite with LocalSparkContext {
results.collect()
}
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
assert(thrown.getMessage.contains("NotSerializableException") ||
thrown.getCause.getClass === classOf[NotSerializableException])

FailureSuiteState.clear()
}
@@ -116,21 +119,24 @@ class FailureSuite extends FunSuite with LocalSparkContext {
sc.parallelize(1 to 10, 2).map(x => a).count()
}
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
assert(thrown.getMessage.contains("NotSerializableException") ||
thrown.getCause.getClass === classOf[NotSerializableException])

// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
}
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))
assert(thrown1.getMessage.contains("NotSerializableException") ||
thrown1.getCause.getClass === classOf[NotSerializableException])

// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
}
assert(thrown2.getClass === classOf[SparkException])
assert(thrown2.getMessage.contains("NotSerializableException"))
assert(thrown2.getMessage.contains("NotSerializableException") ||
thrown2.getCause.getClass === classOf[NotSerializableException])

FailureSuiteState.clear()
}
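The relaxed assertions above reflect that the failure can now arrive by either of two routes: lazily, with the NotSerializableException name embedded in the SparkException message, or proactively, with the NotSerializableException attached as the cause. A hypothetical helper (not part of the patch) makes the shared pattern explicit:

```scala
import java.io.NotSerializableException
import org.apache.spark.SparkException

// Hypothetical convenience; the suite inlines this check at each call site.
def assertSerializationFailure(thrown: SparkException): Unit = {
  assert(thrown.getMessage.contains("NotSerializableException") ||
    thrown.getCause.isInstanceOf[NotSerializableException])
}
```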
75 changes: 52 additions & 23 deletions core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
@@ -19,9 +19,29 @@ package org.apache.spark

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
// Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
test("basic inference of Orderings"){
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10)

// These RDD methods are in the companion object so that the unserializable ScalaTest Engine
// won't be reachable from the closure object

// Infer orderings after basic maps to particular types
val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd)
basicMapExpectations.map({case (met, explain) => assert(met, explain)})

// Infer orderings for other RDD methods
val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd)
otherRDDMethodExpectations.map({case (met, explain) => assert(met, explain)})
}
}

private object ImplicitOrderingSuite {
class NonOrderedClass {}

class ComparableClass extends Comparable[ComparableClass] {
@@ -31,27 +51,36 @@ class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
class OrderedClass extends Ordered[OrderedClass] {
override def compare(o: OrderedClass): Int = ???
}

// Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
test("basic inference of Orderings"){
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10)

// Infer orderings after basic maps to particular types
assert(rdd.map(x => (x, x)).keyOrdering.isDefined)
assert(rdd.map(x => (1, x)).keyOrdering.isDefined)
assert(rdd.map(x => (x.toString, x)).keyOrdering.isDefined)
assert(rdd.map(x => (null, x)).keyOrdering.isDefined)
assert(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty)
assert(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined)
assert(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined)

// Infer orderings for other RDD methods
assert(rdd.groupBy(x => x).keyOrdering.isDefined)
assert(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty)
assert(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined)
assert(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined)
assert(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined)
assert(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined)

def basicMapExpectations(rdd: RDD[Int]) = {
List((rdd.map(x => (x, x)).keyOrdering.isDefined,
"rdd.map(x => (x, x)).keyOrdering.isDefined"),
(rdd.map(x => (1, x)).keyOrdering.isDefined,
"rdd.map(x => (1, x)).keyOrdering.isDefined"),
(rdd.map(x => (x.toString, x)).keyOrdering.isDefined,
"rdd.map(x => (x.toString, x)).keyOrdering.isDefined"),
(rdd.map(x => (null, x)).keyOrdering.isDefined,
"rdd.map(x => (null, x)).keyOrdering.isDefined"),
(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty,
"rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"),
(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined,
"rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"),
(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined,
"rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined"))
}
}

def otherRDDMethodExpectations(rdd: RDD[Int]) = {
List((rdd.groupBy(x => x).keyOrdering.isDefined,
"rdd.groupBy(x => x).keyOrdering.isDefined"),
(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty,
"rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"),
(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined,
"rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"),
(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined,
"rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"),
(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined,
"rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined"),
(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined,
"rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined"))
}
}
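A simplified sketch (class names are stand-ins, not from the patch) of the motivation for moving the expectations into the companion object: a closure created inside a method of the suite can end up with the suite instance on its `$outer` chain, and the suite holds ScalaTest's non-serializable Engine, whereas a closure created in a standalone object has no enclosing instance to drag along.

```scala
// Illustrative stand-ins only.
class SomeSuite {
  private val engine = new Object  // plays the role of ScalaTest's Engine

  def expectationsInsideSuite(xs: Seq[Int]): Seq[Int] =
    // Depending on the Scala version and nesting, this closure may keep a
    // reference to the enclosing SomeSuite instance (and thus to `engine`),
    // making it fail a proactive serializability check.
    xs.map(x => x + 1)
}

object SomeSuiteHelpers {
  def expectationsInObject(xs: Seq[Int]): Seq[Int] =
    // Defined in an object: there is no enclosing instance to capture.
    xs.map(x => x + 1)
}
```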
90 changes: 90 additions & 0 deletions core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer;

import java.io.NotSerializableException

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkException
import org.apache.spark.SharedSparkContext

/* A trivial (but unserializable) container for trivial functions */
class UnserializableClass {
def op[T](x: T) = x.toString

def pred[T](x: T) = x.toString.length % 2 == 0
}

class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {

def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)

test("throws expected serialization exceptions on actions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
data.map(uc.op(_)).count
}

assert(ex.getMessage.contains("Task not serializable"))
}

// There is probably a cleaner way to eliminate boilerplate here, but we're
// iterating over a map from transformation names to functions that perform that
// transformation on a given RDD, creating one test case for each

for (transformation <-
Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
"mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
"mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
"mapPartitionsWithContext" -> xmapPartitionsWithContext _,
"filterWith" -> xfilterWith _)) {
val (name, xf) = transformation

test(s"$name transformations throw proactive serialization exceptions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
xf(data, uc)
}

assert(ex.getMessage.contains("Task not serializable"),
s"RDD.$name doesn't proactively throw NotSerializableException")
}
}

private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y=>uc.op(y))
private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))
private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))
private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y=>uc.op(y)))
private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))

}
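A hedged sketch of what the suite exercises, assuming the SparkContext `sc` provided by SharedSparkContext: before this patch, serializing the closure that captures `uc` failed only when a job was actually submitted; with the proactive check in SparkContext.clean, the same failure is raised at the transformation call site.

```scala
val uc = new UnserializableClass
val data = sc.parallelize(0 until 1000).map(_.toString)

// With checkSerializable = true (the default), the next line throws
// SparkException("Task not serializable") with a NotSerializableException
// as its cause, because the closure captures the non-serializable `uc`.
val mapped = data.map(y => uc.op(y))

// Before the patch, the exception would only have appeared at action time:
// mapped.count()
```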
streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -532,23 +532,32 @@ abstract class DStream[T: ClassTag] (
* 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
// because the DStream is reachable from the outer object here, and because
// DStreams can't be serialized with closures, we can't proactively check
// it for serializability and so we pass the optional false to SparkContext.clean
new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
}

/**
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
// because the DStream is reachable from the outer object here, and because
// DStreams can't be serialized with closures, we can't proactively check
// it for serializability and so we pass the optional false to SparkContext.clean
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
Contributor: Why do we need to set this to false for the tests? Is this because the tests don't have a serializer set up?

Contributor Author: As far as I can tell, the issue here is that the DStream object is reachable from transformFunc(r), and DStream.writeObject generates a runtime exception if it is called from within the closure serializer.

Contributor: document that inline (otherwise we will run into that again)

Contributor: @willb I think you missed this. Make sure you add comment above this line to explain the reason why we do not check serializable ...

Contributor: and for all other instances where that is set to false too

}

/**
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
val cleanedF = context.sparkContext.clean(transformFunc)
// because the DStream is reachable from the outer object here, and because
// DStreams can't be serialized with closures, we can't proactively check
// it for serializability and so we pass the optional false to SparkContext.clean
val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -563,7 +572,10 @@
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
// because the DStream is reachable from the outer object here, and because
// DStreams can't be serialized with closures, we can't proactively check
// it for serializability and so we pass the optional false to SparkContext.clean
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}

@@ -574,7 +586,10 @@
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
// because the DStream is reachable from the outer object here, and because
// DStreams can't be serialized with closures, we can't proactively check
// it for serializability and so we pass the optional false to SparkContext.clean
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]
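A hedged sketch of the failure mode discussed in the review thread above (the classes are simplified stand-ins, not the real Spark types): once a DStream is reachable from a closure's outer chain, eagerly serializing that closure invokes DStream.writeObject outside of the checkpointing path, which raises a runtime exception even though the streaming job itself would run fine; hence these call sites pass checkSerializable = false.

```scala
import java.io.{IOException, ObjectOutputStream}

// Stand-in for a DStream: refuses to be serialized outside checkpointing
// (the real DStream.writeObject performs a more involved validity check).
class FakeDStream extends Serializable {
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit =
    throw new IOException("FakeDStream cannot be serialized here")
}

class FakeStreamOps(stream: FakeDStream) {
  // The function value closes over `stream`, so a proactive serialization
  // check on this closure would hit FakeDStream.writeObject and fail up
  // front, which is why the real code skips the check for these closures.
  def transformFunc: String => String = s => s + stream.toString
}
```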