From d5947b33f8e22ee498473abd59795d4f15a7b198 Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 14 Mar 2014 11:40:56 -0500 Subject: [PATCH 01/12] Ensure assertions in Graph.apply are asserted. The Graph.apply test in GraphSuite had some assertions in a closure in a graph transformation. This caused two problems: 1. because assert() was called, test classes were reachable from the closures, which made them not serializable, and 2. (more importantly) these assertions never actually executed, since they occurred within a lazy map(). This commit simply changes the Graph.apply test to collect the graph triplets so it can assert about each triplet from a map method. --- graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 28d34dd9a1a41..c65e36636fe10 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.map { et => + graph.triplets.collect.map { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } From 21b4b063372bfbdd289ff770d3d38cb4453e7ca6 Mon Sep 17 00:00:00 2001 From: William Benton Date: Wed, 12 Mar 2014 21:56:32 -0500 Subject: [PATCH 02/12] Test cases for SPARK-897. Tests to make sure that passing an unserializable closure to a transformation fails fast. --- .../ProactiveClosureSerializationSuite.scala | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala new file mode 100644 index 0000000000000..aa25c96a637aa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.serializer; + +import java.io.NotSerializableException + +import org.scalatest.FunSuite + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkException +import org.apache.spark.SharedSparkContext + +/* A trivial (but unserializable) container for trivial functions */ +class UnserializableClass { + def op[T](x: T) = x.toString + + def pred[T](x: T) = x.toString.length % 2 == 0 +} + +class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { + + def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + + test("throws expected serialization exceptions on actions") { + val (data, uc) = fixture + + val ex = intercept[SparkException] { + data.map(uc.op(_)).count + } + + assert(ex.getMessage.matches(".*Task not serializable.*")) + } + + // There is probably a cleaner way to eliminate boilerplate here, but we're + // iterating over a map from transformation names to functions that perform that + // transformation on a given RDD, creating one test case for each + + for (transformation <- + Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, "mapWith" -> xmapWith _, + "mapPartitions" -> xmapPartitions _, "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _, + "mapPartitionsWithContext" -> xmapPartitionsWithContext _, "filterWith" -> xfilterWith _)) { + val (name, xf) = transformation + + test(s"$name transformations throw proactive serialization exceptions") { + val (data, uc) = fixture + + val ex = intercept[SparkException] { + xf(data, uc) + } + + assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException") + } + } + + private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.map(y=>uc.op(y)) + private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapWith(x => x.toString)((x,y)=>x + uc.op(y)) + private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y=>Seq(uc.op(y))) + private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filter(y=>uc.pred(y)) + private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filterWith(x => x.toString)((x,y)=>uc.pred(y)) + private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y=>uc.op(y))) + private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) + private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y))) + +} From d8df3dbbed3e1c9e16a2c2002c04c2d2e6a0b2e2 Mon Sep 17 00:00:00 2001 From: William Benton Date: Thu, 13 Mar 2014 14:40:42 -0500 Subject: [PATCH 03/12] Adds proactive closure-serializability checking ClosureCleaner.clean now checks to ensure that its closure argument is serializable by default and throws a SparkException with the underlying NotSerializableException in the detail message otherwise. As a result, transformation invocations with unserializable closures will fail at their call sites rather than when they actually execute. ClosureCleaner.clean now takes a second boolean argument; pass false to disable serializability-checking behavior at call sites where this behavior isn't desired.
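To make the new behavior concrete before the diff below, here is a minimal sketch of what a caller sees once this check is in place (the Unserializable class and the existing SparkContext `sc` are illustrative assumptions, not part of the patch):

    class Unserializable { def tag(x: Int): String = x.toString }  // does not extend Serializable
    val helper = new Unserializable

    try {
      // The closure captures `helper`, so ClosureCleaner.clean serializes it as a check
      // inside map() and the failure surfaces here, at the call site...
      sc.parallelize(1 to 10).map(x => helper.tag(x))
    } catch {
      case e: org.apache.spark.SparkException =>
        // ...wrapped as "Task not serializable: java.io.NotSerializableException: ..."
        println(e.getMessage)
    }
    // Previously the same code failed only when an action such as count() ran the job.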
--- .../org/apache/spark/util/ClosureCleaner.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index a8d20ee332355..fe33fa841a0d8 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -26,6 +26,8 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, Cl import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import org.apache.spark.SparkException private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it @@ -101,7 +103,7 @@ private[spark] object ClosureCleaner extends Logging { } } - def clean(func: AnyRef) { + def clean(func: AnyRef, checkSerializable: Boolean = true) { // TODO: cache outerClasses / innerClasses / accessedFields val outerClasses = getOuterClasses(func) val innerClasses = getInnerClasses(func) @@ -150,6 +152,18 @@ private[spark] object ClosureCleaner extends Logging { field.setAccessible(true) field.set(func, outer) } + + if (checkSerializable) { + ensureSerializable(func) + } + } + + private def ensureSerializable(func: AnyRef) { + try { + SparkEnv.get.closureSerializer.newInstance().serialize(func) + } catch { + case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString) + } } private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { From 4ecf84100e22224aed204c3c4251c6ab20ff8bf6 Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 14 Mar 2014 12:33:33 -0500 Subject: [PATCH 04/12] Make proactive serializability checking optional. SparkContext.clean uses ClosureCleaner's proactive serializability checking by default. This commit adds an overloaded clean method to SparkContext that allows clients to specify that serializability checking should not occur as part of closure cleaning. --- .../src/main/scala/org/apache/spark/SparkContext.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a1003b7925715..c6f3b7a8494f8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1032,7 +1032,15 @@ class SparkContext( * (removes unreferenced variables in $outer's, updates REPL variables) */ private[spark] def clean[F <: AnyRef](f: F): F = { - ClosureCleaner.clean(f) + clean(f, true) + } + + /** + * Clean a closure to make it ready to serialized and send to tasks + * (removes unreferenced variables in $outer's, updates REPL variables) + */ + private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean): F = { + ClosureCleaner.clean(f, checkSerializable) f } From d6e8dd6469ef24ee0631d7c5bb424498715c59f5 Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 14 Mar 2014 12:34:42 -0500 Subject: [PATCH 05/12] Don't check serializability of DStream transforms. Since the DStream is reachable from within these closures, they aren't checkable by the straightforward technique of passing them to the closure serializer. 
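The overloaded clean introduced in the previous patch is what makes this opt-out possible. A brief sketch of why the check has to be skipped here (`lines` stands in for any DStream[String] and is an assumption, not code from the patch):

    // The wrapper closure built inside DStream.transform refers back to this DStream
    // through `context`, so running it through the closure serializer as a probe would
    // try to serialize the whole DStream graph. It is therefore cleaned with the
    // serializability check disabled, roughly:
    //   val cleanedF = context.sparkContext.clean(transformFunc, false)
    // User code is unchanged; the opt-out is purely internal:
    val lengths = lines.transform(rdd => rdd.map(_.length))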
--- .../org/apache/spark/streaming/dstream/DStream.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 6bff56a9d332a..dc79483ddaccb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -533,7 +533,7 @@ abstract class DStream[T: ClassTag] ( * on each RDD of 'this' DStream. */ def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r))) + transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false)) } /** @@ -541,7 +541,7 @@ abstract class DStream[T: ClassTag] ( * on each RDD of 'this' DStream. */ def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - val cleanedF = context.sparkContext.clean(transformFunc) + val cleanedF = context.sparkContext.clean(transformFunc, false) val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) cleanedF(rdds.head.asInstanceOf[RDD[T]], time) @@ -556,7 +556,7 @@ abstract class DStream[T: ClassTag] ( def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V] ): DStream[V] = { - val cleanedF = ssc.sparkContext.clean(transformFunc) + val cleanedF = ssc.sparkContext.clean(transformFunc, false) transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2)) } @@ -567,7 +567,7 @@ abstract class DStream[T: ClassTag] ( def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V] ): DStream[V] = { - val cleanedF = ssc.sparkContext.clean(transformFunc) + val cleanedF = ssc.sparkContext.clean(transformFunc, false) val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 2) val rdd1 = rdds(0).asInstanceOf[RDD[T]] From 12c63a7e03bce359fd7eb7faf0a054bd32f85824 Mon Sep 17 00:00:00 2001 From: William Benton Date: Tue, 18 Mar 2014 09:55:57 -0500 Subject: [PATCH 06/12] Added tests for variable capture in closures The two tests added to ClosureCleanerSuite ensure that variable values are captured at RDD definition time, not at job-execution time. 
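Condensed from the tests in the diff below, the semantics being pinned down look like this (assumes an active SparkContext `sc`; the behavior itself only lands with the capture change later in this series):

    var zero = 0
    val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
    val onesPlusZeroes = ones.map(_ + zero)  // `zero` should be captured here, at definition time

    zero = 5                                 // a later mutation must not leak into the RDD

    assert(ones.reduce(_ + _) == onesPlusZeroes.reduce(_ + _))  // both sums are 5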
--- .../spark/util/ClosureCleanerSuite.scala | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 439e5644e20a3..24308820ea240 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -50,6 +50,20 @@ class ClosureCleanerSuite extends FunSuite { val obj = new TestClassWithNesting(1) assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 } + + test("capturing free variables in closures at RDD definition") { + val obj = new TestCaptureVarClass() + val (ones, onesPlusZeroes) = obj.run() + + assert(ones === onesPlusZeroes) + } + + test("capturing free variable fields in closures at RDD definition") { + val obj = new TestCaptureFieldClass() + val (ones, onesPlusZeroes) = obj.run() + + assert(ones === onesPlusZeroes) + } } // A non-serializable class we create in closures to make sure that we aren't @@ -143,3 +157,37 @@ class TestClassWithNesting(val y: Int) extends Serializable { } } } + +class TestCaptureFieldClass extends Serializable { + class ZeroBox extends Serializable { + var zero = 0 + } + + def run(): (Int, Int) = { + val zb = new ZeroBox + + withSpark(new SparkContext("local", "test")) {sc => + val ones = sc.parallelize(Array(1, 1, 1, 1, 1)) + val onesPlusZeroes = ones.map(_ + zb.zero) + + zb.zero = 5 + + (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _)) + } + } +} + +class TestCaptureVarClass extends Serializable { + def run(): (Int, Int) = { + var zero = 0 + + withSpark(new SparkContext("local", "test")) {sc => + val ones = sc.parallelize(Array(1, 1, 1, 1, 1)) + val onesPlusZeroes = ones.map(_ + zero) + + zero = 5 + + (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _)) + } + } +} From 8ee3ee73608545b6ee3ae966650afca55d0bf347 Mon Sep 17 00:00:00 2001 From: William Benton Date: Thu, 20 Mar 2014 10:48:17 -0500 Subject: [PATCH 07/12] Predictable closure environment capture The environments of serializable closures are now captured as part of closure cleaning. Since we already proactively check most closures for serializability, ClosureCleaner.clean now returns the result of deserializing the serialized version of the cleaned closure. 
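The mechanism, distilled from the ClosureCleaner change in the diff below (a sketch; the helper name captureNow is chosen here for illustration and is not the name used in the patch):

    import scala.reflect.ClassTag
    import org.apache.spark.{SparkEnv, SparkException}

    // Round-tripping the cleaned closure through the closure serializer both proves that
    // it is serializable and yields a copy whose free variables are frozen at the values
    // they hold right now.
    def captureNow[T: ClassTag](closure: T): T =
      try {
        val ser = SparkEnv.get.closureSerializer.newInstance()
        ser.deserialize[T](ser.serialize[T](closure))
      } catch {
        case ex: Exception => throw new SparkException("Task not serializable: " + ex)
      }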
--- .../main/scala/org/apache/spark/SparkContext.scala | 5 ++--- .../scala/org/apache/spark/util/ClosureCleaner.scala | 11 ++++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c6f3b7a8494f8..f3c5b420db8a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1031,7 +1031,7 @@ class SparkContext( * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) */ - private[spark] def clean[F <: AnyRef](f: F): F = { + private[spark] def clean[F <: AnyRef : ClassTag](f: F): F = { clean(f, true) } @@ -1039,9 +1039,8 @@ class SparkContext( * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) */ - private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean): F = { + private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean): F = { ClosureCleaner.clean(f, checkSerializable) - f } /** diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index fe33fa841a0d8..2db06548a1ac2 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,6 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.Map import scala.collection.mutable.Set +import scala.reflect.ClassTag + import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ @@ -103,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging { } } - def clean(func: AnyRef, checkSerializable: Boolean = true) { + def clean[F <: AnyRef : ClassTag](func: F, checkSerializable: Boolean = true): F = { // TODO: cache outerClasses / innerClasses / accessedFields val outerClasses = getOuterClasses(func) val innerClasses = getInnerClasses(func) @@ -155,12 +157,15 @@ private[spark] object ClosureCleaner extends Logging { if (checkSerializable) { ensureSerializable(func) + } else { + func } } - private def ensureSerializable(func: AnyRef) { + private def ensureSerializable[T: ClassTag](func: T) = { try { - SparkEnv.get.closureSerializer.newInstance().serialize(func) + val serializer = SparkEnv.get.closureSerializer.newInstance() + serializer.deserialize[T](serializer.serialize[T](func)) } catch { case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString) } From 12ef6e3af0a2f8df2375e774c9bfcc313fb84d8a Mon Sep 17 00:00:00 2001 From: William Benton Date: Tue, 25 Mar 2014 23:45:45 -0500 Subject: [PATCH 08/12] Skip proactive closure capture for runJob There are two possible cases for runJob calls: either they are called by RDD action methods from inside Spark or they are called from client code. There's no need to proactively check the closure argument to runJob for serializability or force variable capture in either case: 1. if they are called by RDD actions, their closure arguments consist of mapping an already-serializable closure (with an already-frozen environment) to each element in the RDD; 2. 
whether the call comes from an RDD action or directly from client code, the closure is about to execute, so the benefit of proactively checking for serializability (or ensuring immediate variable capture) is nonexistent. (Note that ensuring capture via serializability on closure arguments to runJob also causes pyspark accumulators to fail to update.) --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 +++- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f3c5b420db8a2..4c5b35d4025bf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -896,7 +896,9 @@ class SparkContext( require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p") } val callSite = getCallSite - val cleanedFunc = clean(func) + // There's no need to check this function for serializability, + // since it will be run right away. + val cleanedFunc = clean(func, false) logInfo("Starting job: " + callSite) val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1b43040c6d918..42a1e3faec722 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -640,14 +640,16 @@ abstract class RDD[T: ClassTag]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - sc.runJob(this, (iter: Iterator[T]) => f(iter)) + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } /** From 97e9d916a3c213b6152cb4513caa19a1ae56df4e Mon Sep 17 00:00:00 2001 From: William Benton Date: Wed, 26 Mar 2014 11:31:56 -0500 Subject: [PATCH 09/12] Split closure-serializability failure tests This splits the test identifying expected failures due to closure serializability into three cases.
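The foreach case among these tests exercises the fail-fast path added by the previous patch directly; a condensed sketch (the NonSerializable class and `sc` stand in for the test fixtures):

    class NonSerializable                    // deliberately not Serializable
    val a = new NonSerializable
    try {
      // foreach cleans its closure with the serializability check enabled, so this
      // now fails here rather than during task serialization inside the scheduler.
      sc.parallelize(1 to 10, 2).foreach(x => println(a))
    } catch {
      case e: org.apache.spark.SparkException =>
        assert(e.getMessage.contains("NotSerializableException"))
    }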
--- .../scala/org/apache/spark/FailureSuite.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index f3fb64d87a2fd..1efbcab48adaf 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { FailureSuiteState.clear() } - test("failure because task closure is not serializable") { + test("failure because closure in final-stage task is not serializable") { sc = new SparkContext("local[1,1]", "test") val a = new NonSerializable @@ -118,6 +118,13 @@ class FailureSuite extends FunSuite with LocalSparkContext { assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("NotSerializableException")) + FailureSuiteState.clear() + } + + test("failure because closure in early-stage task is not serializable") { + sc = new SparkContext("local[1,1]", "test") + val a = new NonSerializable + // Non-serializable closure in an earlier stage val thrown1 = intercept[SparkException] { sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() @@ -125,6 +132,13 @@ class FailureSuite extends FunSuite with LocalSparkContext { assert(thrown1.getClass === classOf[SparkException]) assert(thrown1.getMessage.contains("NotSerializableException")) + FailureSuiteState.clear() + } + + test("failure because closure in foreach task is not serializable") { + sc = new SparkContext("local[1,1]", "test") + val a = new NonSerializable + // Non-serializable closure in foreach function val thrown2 = intercept[SparkException] { sc.parallelize(1 to 10, 2).foreach(x => println(a)) @@ -135,6 +149,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { FailureSuiteState.clear() } + // TODO: Need to add tests with shuffle fetch failures. 
} From 9b56ce0e7a81b3e40713ef4a59bd96f18176a626 Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 4 Apr 2014 16:32:13 -0500 Subject: [PATCH 10/12] Added array-element capture test --- .../spark/util/ClosureCleanerSuite.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 24308820ea240..c635da6cacd70 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -64,6 +64,13 @@ class ClosureCleanerSuite extends FunSuite { assert(ones === onesPlusZeroes) } + + test("capturing arrays in closures at RDD definition") { + val obj = new TestCaptureArrayEltClass() + val (observed, expected) = obj.run() + + assert(observed === expected) + } } // A non-serializable class we create in closures to make sure that we aren't @@ -177,6 +184,19 @@ class TestCaptureFieldClass extends Serializable { } } +class TestCaptureArrayEltClass extends Serializable { + def run(): (Int, Int) = { + withSpark(new SparkContext("local", "test")) {sc => + val rdd = sc.parallelize(1 to 10) + val data = Array(1, 2, 3) + val expected = data(0) + val mapped = rdd.map(x => data(0)) + data(0) = 4 + (mapped.first, expected) + } + } +} + class TestCaptureVarClass extends Serializable { def run(): (Int, Int) = { var zero = 0 From b3d9c8656ac6bce766447b2324f9e5728f79e04d Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 4 Apr 2014 16:39:55 -0500 Subject: [PATCH 11/12] Fixed style issues in tests --- .../ProactiveClosureSerializationSuite.scala | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index aa25c96a637aa..76662264e7e94 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -51,9 +51,9 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex // transformation on a given RDD, creating one test case for each for (transformation <- - Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, "mapWith" -> xmapWith _, - "mapPartitions" -> xmapPartitions _, "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _, - "mapPartitionsWithContext" -> xmapPartitionsWithContext _, "filterWith" -> xfilterWith _)) { + Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _, + "mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _, + "mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) { val (name, xf) = transformation test(s"$name transformations throw proactive serialization exceptions") { @@ -67,13 +67,28 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex } } - private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.map(y=>uc.op(y)) - private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapWith(x => x.toString)((x,y)=>x + uc.op(y)) - private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y=>Seq(uc.op(y))) - private def xfilter(x: RDD[String], uc: UnserializableClass): 
RDD[String] = x.filter(y=>uc.pred(y)) - private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filterWith(x => x.toString)((x,y)=>uc.pred(y)) - private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y=>uc.op(y))) - private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) - private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y))) + def map(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.map(y => uc.op(y)) + + def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.mapWith(x => x.toString)((x,y) => x + uc.op(y)) + + def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.flatMap(y=>Seq(uc.op(y))) + + def filter(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.filter(y=>uc.pred(y)) + + def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.filterWith(x => x.toString)((x,y) => uc.pred(y)) + + def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.mapPartitions(_.map(y => uc.op(y))) + + def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y))) + + def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y))) } From f4cafa0712be5011ff9dc7e5e6d27a7077095b28 Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 4 Apr 2014 17:15:50 -0500 Subject: [PATCH 12/12] Stylistic changes and cleanups --- .../scala/org/apache/spark/SparkContext.scala | 19 ++++++++----------- .../apache/spark/util/ClosureCleaner.scala | 8 ++++---- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4c5b35d4025bf..5339d97bb5fb1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1028,21 +1028,18 @@ class SparkContext( def cancelAllJobs() { dagScheduler.cancelAllJobs() } - - /** - * Clean a closure to make it ready to serialized and send to tasks - * (removes unreferenced variables in $outer's, updates REPL variables) - */ - private[spark] def clean[F <: AnyRef : ClassTag](f: F): F = { - clean(f, true) - } - + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) + * + * @param f closure to be cleaned and optionally serialized + * @param captureNow whether or not to serialize this closure and capture any free + * variables immediately; defaults to true. If this is set and f is not serializable, + * it will raise an exception. 
*/ - private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean): F = { - ClosureCleaner.clean(f, checkSerializable) + private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = { + ClosureCleaner.clean(f, captureNow) } /** diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 2db06548a1ac2..432278e17e13d 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -105,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging { } } - def clean[F <: AnyRef : ClassTag](func: F, checkSerializable: Boolean = true): F = { + def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = { // TODO: cache outerClasses / innerClasses / accessedFields val outerClasses = getOuterClasses(func) val innerClasses = getInnerClasses(func) @@ -155,14 +155,14 @@ private[spark] object ClosureCleaner extends Logging { field.set(func, outer) } - if (checkSerializable) { - ensureSerializable(func) + if (captureNow) { + cloneViaSerializing(func) } else { func } } - private def ensureSerializable[T: ClassTag](func: T) = { + private def cloneViaSerializing[T: ClassTag](func: T): T = { try { val serializer = SparkEnv.get.closureSerializer.newInstance() serializer.deserialize[T](serializer.serialize[T](func))