From 470380e48a9bf574ee6cfc2700bd044b70276cd8 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 3 Sep 2016 10:26:52 -0700
Subject: [PATCH 1/4] Add regression test.

---
 core/src/test/scala/org/apache/spark/DistributedSuite.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 4ee0e00fde506..d1d83f7f3ba57 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -174,6 +174,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
         new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
       assert(deserialized === (1 to 100).toList)
     }
+    // This will exercise the getRemoteBytes / getRemoteValues code paths:
+    blockManager.get(blockId)
   }
 
   Seq(
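
Patch 1 only adds a call to blockManager.get(blockId) with no type argument; that call shape is
exactly where the element type gets lost, because an implicit ClassTag resolved without any type
information in scope falls back to Nothing, and any serializer choice keyed off the tag can no
longer see the real element class. A standalone sketch of that inference behaviour (plain Scala,
illustrative names only, not Spark code):

    import scala.reflect.ClassTag

    object TagInferenceSketch {
      // Stand-in for a context-bound read path such as get[T: ClassTag](...):
      // it simply reports which ClassTag the compiler supplied.
      def readAs[T: ClassTag](): ClassTag[T] = implicitly[ClassTag[T]]

      def main(args: Array[String]): Unit = {
        // Caller names the element type: the tag carries Int.
        println(readAs[Int]().runtimeClass)   // int
        // Caller omits the type argument: T is inferred as Nothing, so the tag
        // says nothing useful about the elements that will be deserialized.
        println(readAs().runtimeClass)        // class scala.runtime.Nothing$
      }
    }
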
From 9eb75f57bbb7ee0c555bbdd26cf4187ee0ad3671 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 3 Sep 2016 10:31:43 -0700
Subject: [PATCH 2/4] Fix bug by threading proper ClassTag

---
 .../scala/org/apache/spark/rdd/BlockRDD.scala      |  2 +-
 .../org/apache/spark/storage/BlockManager.scala    | 15 ++++++++-------
 .../scala/org/apache/spark/DistributedSuite.scala  |  2 +-
 .../rdd/WriteAheadLogBackedBlockRDD.scala          |  2 +-
 4 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 63d1d1767a8cb..d47b75544fdba 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
     assertValid()
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDPartition].blockId
-    blockManager.get(blockId) match {
+    blockManager.get[T](blockId) match {
       case Some(block) => block.data.asInstanceOf[Iterator[T]]
       case None =>
         throw new Exception("Could not compute split, block " + blockId + " not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c72f28e00cdbc..0614646771bd0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -520,10 +520,11 @@ private[spark] class BlockManager(
    *
    * This does not acquire a lock on this block in this JVM.
    */
-  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+  private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+    val ct = implicitly[ClassTag[T]]
     getRemoteBytes(blockId).map { data =>
       val values =
-        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
       new BlockResult(values, DataReadMethod.Network, data.size)
     }
   }
@@ -602,13 +603,13 @@ private[spark] class BlockManager(
    * any locks if the block was fetched from a remote block manager. The read lock will
    * automatically be freed once the result's `data` iterator is fully consumed.
    */
-  def get(blockId: BlockId): Option[BlockResult] = {
+  def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
     val local = getLocalValues(blockId)
     if (local.isDefined) {
       logInfo(s"Found block $blockId locally")
       return local
     }
-    val remote = getRemoteValues(blockId)
+    val remote = getRemoteValues[T](blockId)
     if (remote.isDefined) {
       logInfo(s"Found block $blockId remotely")
       return remote
     }
@@ -660,7 +661,7 @@ private[spark] class BlockManager(
       makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
     // Attempt to read the block from local or remote storage. If it's present, then we don't need
     // to go through the local-get-or-put path.
-    get(blockId) match {
+    get[T](blockId)(classTag) match {
       case Some(block) =>
         return Left(block)
       case _ =>
@@ -1204,8 +1205,8 @@ private[spark] class BlockManager(
   /**
    * Read a block consisting of a single object.
    */
-  def getSingle(blockId: BlockId): Option[Any] = {
-    get(blockId).map(_.data.next())
+  def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+    get[T](blockId).map(_.data.next().asInstanceOf[T])
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index d1d83f7f3ba57..f7beaa0d0fcb6 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -175,7 +175,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
       assert(deserialized === (1 to 100).toList)
     }
     // This will exercise the getRemoteBytes / getRemoteValues code paths:
-    blockManager.get(blockId)
+    assert(blockManager.get[Int](blockId).get.data.toSet === (1 to 1000).toSet)
   }
 
   Seq(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index 53fccd8d5e6ed..bf7b7cdc1c655 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -120,7 +120,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
     val blockId = partition.blockId
 
     def getBlockFromBlockManager(): Option[Iterator[T]] = {
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
+      blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]])
     }
 
     def getBlockFromWriteAheadLog(): Iterator[T] = {
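
Patch 2 makes the read path generic and threads the caller's ClassTag from BlockRDD.compute
through get and getRemoteValues down to the serializer, so the deserializer finally sees the real
element type. A compact sketch of the same pattern with a toy in-memory store in place of Spark's
BlockManager (all names below are illustrative, not Spark APIs):

    import scala.collection.mutable
    import scala.reflect.{classTag, ClassTag}

    // Toy stand-in for a block store whose read path needs the element type to
    // pick a deserializer; only the ClassTag threading is modelled here.
    object TaggedBlockStoreSketch {
      private val blocks = mutable.Map.empty[String, (ClassTag[_], Seq[Any])]

      def put[T: ClassTag](id: String, values: Seq[T]): Unit = {
        blocks(id) = (classTag[T], values)
      }

      // The read side is generic too, so a caller such as an RDD's compute()
      // can pass its own element tag instead of letting it default away.
      def get[T: ClassTag](id: String): Option[Iterator[T]] =
        blocks.get(id).map { case (storedTag, values) =>
          // A real store would use classTag[T] to choose a serializer; here we
          // just check that it still matches what was written.
          require(storedTag == classTag[T], s"tag mismatch: $storedTag vs ${classTag[T]}")
          values.iterator.map(_.asInstanceOf[T])
        }

      def main(args: Array[String]): Unit = {
        put("rdd_0_0", 1 to 3)
        println(get[Int]("rdd_0_0").map(_.toList)) // Some(List(1, 2, 3))
      }
    }
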
From 222a0ba88867b9f09d61f34be906e56ab0c60c73 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 4 Sep 2016 13:46:15 -0700
Subject: [PATCH 3/4] Fix test assertion.

---
 core/src/test/scala/org/apache/spark/DistributedSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index f7beaa0d0fcb6..6ea47b91ca34e 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -175,7 +175,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
       assert(deserialized === (1 to 100).toList)
     }
     // This will exercise the getRemoteBytes / getRemoteValues code paths:
-    assert(blockManager.get[Int](blockId).get.data.toSet === (1 to 1000).toSet)
+    assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet)
  }
 
   Seq(
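
The corrected assertion has to union all of the RDD's blocks: each block holds only one
partition's slice of the data (the surrounding test shows a single block deserializing to
1 to 100), so only blockIds.flatMap(...) can reproduce the full 1 to 1000. A tiny self-contained
illustration of the difference, using a hypothetical partition layout and no Spark:

    object PartitionedAssertionSketch {
      def main(args: Array[String]): Unit = {
        // Ten "blocks" of 100 elements each, mirroring a 10-partition RDD of 1 to 1000.
        val blocks: Seq[Seq[Int]] = (1 to 1000).grouped(100).toSeq
        assert(blocks.head.toSet != (1 to 1000).toSet)    // one block is not enough
        assert(blocks.flatten.toSet == (1 to 1000).toSet) // the union of all blocks is
      }
    }
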
From 68db68dbbc3ad7ecfe8180b165155f643c92cf2b Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 4 Sep 2016 13:59:06 -0700
Subject: [PATCH 4/4] Require ClassTag to be passed explicitly

---
 .../org/apache/spark/serializer/SerializerManager.scala    | 7 ++++---
 .../src/test/scala/org/apache/spark/DistributedSuite.scala | 4 ++--
 .../spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala  | 3 ++-
 .../apache/spark/streaming/ReceivedBlockHandlerSuite.scala | 3 ++-
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7b1ec6fcbbbf6..2156d576f1874 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -180,11 +180,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * Deserializes an InputStream into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserializeStream[T: ClassTag](
+  def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream): Iterator[T] = {
+      inputStream: InputStream)
+      (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    getSerializer(implicitly[ClassTag[T]])
+    getSerializer(classTag)
       .newInstance()
       .deserializeStream(wrapStream(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 6ea47b91ca34e..4e36adc8baf3f 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -170,8 +170,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     blockManager.master.getLocations(blockId).foreach { cmId =>
       val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
         blockId.toString)
-      val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
-        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
+      val deserialized = serializerManager.dataDeserializeStream(blockId,
+        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList
       assert(deserialized === (1 to 100).toList)
     }
     // This will exercise the getRemoteBytes / getRemoteValues code paths:
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index bf7b7cdc1c655..0b2ec298132ad 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -163,7 +163,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
         dataRead.rewind()
       }
       serializerManager
-        .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream())
+        .dataDeserializeStream(
+          blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
         .asInstanceOf[Iterator[T]]
     }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index feb5c30c6aa14..7e665454a5400 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfter, Matchers}
@@ -163,7 +164,7 @@ class ReceivedBlockHandlerSuite
       val bytes = reader.read(fileSegment)
       reader.close()
       serializerManager.dataDeserializeStream(
-        generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList
+        generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
     }
     loggedData shouldEqual data
   }
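
Patch 4 drops the implicit context bound on dataDeserializeStream and takes the ClassTag in an
explicit second parameter list, so every caller has to state which element type it is reading
back (an RDD's elementClassTag, or ClassTag.Any in tests where the type is genuinely unknown)
instead of silently relying on implicit resolution. A standalone sketch of that API shape, with a
toy reader standing in for the serializer (illustrative names, not Spark's API):

    import java.io.{ByteArrayInputStream, InputStream}
    import scala.io.Source
    import scala.reflect.ClassTag

    object ExplicitTagSketch {
      // Explicit (non-implicit) second parameter list: callers must name the tag.
      def deserializeStream[T](in: InputStream)(classTag: ClassTag[T]): Iterator[T] = {
        // A real implementation would choose a serializer from classTag; this sketch
        // only checks the tag it was handed and reads whitespace-separated ints.
        require(classTag == ClassTag.Int || classTag == ClassTag.Any,
          s"this toy reader only produces ints, got $classTag")
        Source.fromInputStream(in, "UTF-8").mkString.split("\\s+").iterator
          .map(s => s.toInt.asInstanceOf[T])
      }

      def main(args: Array[String]): Unit = {
        val in = new ByteArrayInputStream("1 2 3".getBytes("UTF-8"))
        // Compiles only because the caller supplies a tag explicitly.
        val ints = deserializeStream(in)(ClassTag.Int).toList
        println(ints) // List(1, 2, 3)
      }
    }
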