diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 8f70744d804d..4bde2c5697e4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -146,6 +146,9 @@ object SparkEnv extends Logging { // Listener bus is only used on the driver if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") + // When running tests in local mode, previous shutdown of sc could have marked it as + // VM shutdown. recheck and disable shutdown flag. + // Utils.doShutdownCheck() } val securityManager = new SecurityManager(conf) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 0a7e1ec53967..fff68661cca3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -40,7 +40,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In */ def writeObject[T: ClassTag](t: T): SerializationStream = { objOut.writeObject(t) - if (counterReset > 0 && counter >= counterReset) { + if (counterReset >= 0 && counter >= counterReset) { objOut.reset() counter = 0 } else { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 9b78228519da..7a44ae880df0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -61,7 +61,8 @@ class HashShuffleWriter[K, V]( } /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { + override def stop(successInput: Boolean): Option[MapStatus] = { + var success = successInput try { if (stopping) { return None @@ -71,7 +72,8 @@ class HashShuffleWriter[K, V]( try { return Some(commitWritesAndBuildStatus()) } catch { - case e: Exception => + case e: Throwable => + success = false // for finally block revertWrites() throw e } @@ -96,9 +98,9 @@ class HashShuffleWriter[K, V]( var totalBytes = 0L var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() + writer.commitAndClose() val size = writer.fileSegment().length + assert(size >= 0) totalBytes += size totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) @@ -116,8 +118,13 @@ class HashShuffleWriter[K, V]( private def revertWrites(): Unit = { if (shuffle != null && shuffle.writers != null) { for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + try { + writer.revertPartialWritesAndClose() + } catch { + // Ensure that all revert's get done - log exception and continue + case ex: Exception => + logError("Exception reverting/closing writers", ex) + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a2687e6be4e3..623405bc8e61 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,11 +17,13 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} +import 
java.io.{BufferedOutputStream, File, FileNotFoundException, FileOutputStream} +import java.io.{SyncFailedException, IOException, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.util.Utils /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -32,23 +34,17 @@ import org.apache.spark.serializer.{SerializationStream, Serializer} */ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - /** * Flush the partial writes and commit them as a single atomic block. Return the * number of bytes written for this commit. */ - def commit(): Long + def commitAndClose() /** * Reverts writes that haven't been flushed yet. Callers should invoke this function * when there are runtime exceptions. */ - def revertPartialWrites() + def revertPartialWritesAndClose() /** * Writes an object. @@ -71,7 +67,13 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { def bytesWritten: Long } -/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +/** + * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * Note, this impl is NOT MT-safe : use external synchronization if you need it to be so. + * + * TODO: Some of the asserts, particularly which use File.* methods can be expensive - ensure they + * are not in critical path. + */ private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, @@ -107,68 +109,296 @@ private[spark] class DiskBlockObjectWriter( private var fos: FileOutputStream = null private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null + + // Did we create this file or was it already present : used in revert to decide + // if we should delete this file or not. Also used to detect if file was deleted + // between creation of BOW and its actual init + private val initiallyExists = file.exists() && file.isFile private val initialPosition = file.length() private var lastValidPosition = initialPosition + private var initialized = false + // closed explicitly ? + private var closed = false + // Attempt to cleanly close ? (could also be closed via revert) + // Note, a cleanly closed file could be subsequently reverted + private var cleanCloseAttempted = false + // Was the file actually opened atleast once. + // Note: initialized/streams change state with close/revert. + private var wasOpenedOnce = false private var _timeWriting = 0L - override def open(): BlockObjectWriter = { - fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) - channel = fos.getChannel() + // Due to some directory creation race issues in spark, it has been observed that + // sometimes file creation happens 'before' the actual directory has been created + // So we attempt to retry atleast once with a mkdirs in case directory was missing. + private def init() { + init(canRetry = true) + } + + private def init(canRetry: Boolean) { + + if (closed) throw new IOException("Already closed") + + assert(! initialized) + assert(! wasOpenedOnce) + var exists = false + try { + exists = file.exists() + if (! exists && initiallyExists && 0 != initialPosition && ! Utils.inShutdown) { + // Was deleted by cleanup thread ? + throw new IOException("file " + file + " cleaned up ? 
exists = " + exists + + ", initiallyExists = " + initiallyExists + ", initialPosition = " + initialPosition) + } + fos = new FileOutputStream(file, true) + } catch { + case fEx: FileNotFoundException => + // There seems to be some race in directory creation. + // Attempts to fix it dont seem to have worked : working around the problem for now. + logDebug("Unable to open " + file + ", canRetry = " + canRetry + ", exists = " + exists + + ", initialPosition = " + initialPosition + ", in shutdown = " + Utils.inShutdown(), fEx) + if (canRetry && ! Utils.inShutdown()) { + // try creating the parent directory if that is the issue. + // Since there can be race with others, dont bother checking for + // success/failure - the call to init() will resolve if fos can be created. + file.getParentFile.mkdirs() + // Note, if directory did not exist, then file does not either - and so + // initialPosition would be zero in either case. + init(canRetry = false) + return + } else throw fEx + } + + try { + // This is to workaround case where creation of object and actual init + // (which can happen much later) happens after a delay and the cleanup thread + // cleaned up the file. + channel = fos.getChannel + val fosPos = channel.position() + if (initialPosition != fosPos) { + throw new IOException("file cleaned up ? " + file.exists() + + ", initialpos = " + initialPosition + + "current len = " + fosPos + ", in shutdown ? " + Utils.inShutdown) + } + + ts = new TimeTrackingOutputStream(fos) + val bos = new BufferedOutputStream(ts, bufferSize) + bs = compressStream(bos) + objOut = serializer.newInstance().serializeStream(bs) + initialized = true + wasOpenedOnce = true; + } finally { + if (! initialized) { + // failed, cleanup state. + val tfos = fos + updateCloseState() + tfos.close() + } + } + } + + private def open(): BlockObjectWriter = { + init() lastValidPosition = initialPosition - bs = compressStream(new BufferedOutputStream(ts, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true this } - override def close() { - if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() + private def updateCloseState() { + + if (null != ts) _timeWriting += ts.timeWriting + + bs = null + channel = null + fos = null + ts = null + objOut = null + initialized = false + } + + private def flushAll() { + if (closed) throw new IOException("Already closed") + + // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the + // serializer stream and the lower level stream. + if (null != objOut) { + objOut.flush() + bs.flush() + } + } + + private def closeAll(needFlush: Boolean, needRevert: Boolean) { + + if (null != objOut) { + val truncatePos = if (needRevert) initialPosition else -1L + assert(! this.closed) + + // In case syncWrites is true or we need to truncate + var cleanlyClosed = false + try { + // Flushing if we need to truncate also. Currently, we reopen to truncate + // so this is not strictly required (since close could write further to streams). + // Keeping it around in case that gets relaxed. + if (needFlush || needRevert) flushAll() + val start = System.nanoTime() - fos.getFD.sync() + try { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + fos.getFD.sync() + } + } catch { + case sfe: SyncFailedException => // ignore + } + // must cause cascading close. 
Note, repeated close on closed streams should not cause + // issues : except some libraries do not honour it - hence not explicitly closing bs/fos + objOut.close() + // bs.close() + // fos.close() _timeWriting += System.nanoTime() - start - } - objOut.close() - _timeWriting += ts.timeWriting + // fos MUST have been closed. + assert(null == channel || !channel.isOpen) + cleanlyClosed = true + + } finally { + + this.closed = true + if (! cleanlyClosed) { + // could not cleanly close. We have two cases here - + // a) normal close, + // b) revert + // If (a) then then streams/data is in inconsistent so we cant really recover + // simply release fd and allow exception to bubble up. + // If (b) and file length >= initialPosition, then truncate file and ignore exception + // else,cause exception to bubble up since we cant recover + assert(null != fos) + try { fos.close() } catch { case ioEx: IOException => /* best case attempt, ignore */ } + } + + updateCloseState() + + // Since close can end up writing data in general case (inspite of flush), + // we reopen to truncate file. + if (needRevert) { + // remove if not earlier existed : best case effort so we dont care about return value + // of delete (it can fail if file was already deleted by cleaner threads for example) + if (! initiallyExists) { + file.delete() + // Explicitly ignore exceptions (when cleanlyClosed = false) and return + // from here. Usually not good idea in finally, but it is ok here. + return + } else { + val fileLen = file.length() + if (fileLen >= truncatePos) { + if (fileLen > truncatePos) DiskBlockObjectWriter.truncateIfExists(file, truncatePos) - channel = null - bs = null - fos = null - ts = null - objOut = null - initialized = false + // Validate length. + assert(truncatePos == file.length() || Utils.inShutdown(), + "truncatePos = " + truncatePos + ", len = " + file.length() + + ", in shutdown = " + Utils.inShutdown()) + + // Explicitly ignore exceptions (when cleanlyClosed = false) and return + // from here. Usually not good idea in finally, but it is ok here. + return + } // else cause the exception to bubble up if thrown + } + } + } + } else { + // it is possible for open to have never been called - no data written to this + // partition for example. so objOut == null + this.closed = true } + initialized = false } - override def isOpen: Boolean = objOut != null + private def validateBytesWritten() { + // This should happen due to file deletion, during cleanup. Ensure bytesWritten is in sane + // state. Note, parallel threads continue to run while shutdown threads are running : so + // this prevents unwanted assertion failures and exception elsewhere. + if (lastValidPosition < initialPosition) { + // This is invoked so that assertions within bytes written are validated. + assert(bytesWritten >= 0) + lastValidPosition = initialPosition + } + } - override def commit(): Long = { + override def commitAndClose() { if (initialized) { - // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the - // serializer stream and the lower level stream. - objOut.flush() - bs.flush() + // opened, file still open + assert(wasOpenedOnce) + // Note, set cleanCloseAttempted even before we finish the close : so that a revert on this + // in case close fails can truncate to previous state ! 
+ cleanCloseAttempted = true + closeAll(needFlush = true, needRevert = false) + val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + assert(prevPos == initialPosition) + assert(null == fos) + + lastValidPosition = file.length() + validateBytesWritten() + // review: remove ? + assert(bytesWritten >= 0, "bytesWritten = " + bytesWritten + + ", initial pos = " + initialPosition + ", last valid pos = " + lastValidPosition) + + } else if (cleanCloseAttempted) { + // opened and closed cleanly + assert(closed) + assert(wasOpenedOnce) + // size should be lastValidPosition, or file deleted due to shutdown. + assert(lastValidPosition == file.length() || Utils.inShutdown, + "lastValidPosition = " + lastValidPosition + + ", file len = " + file.length() + ", exists = " + file.exists()) + } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition + // reverted or never opened. + this.closed = true + assert(initialPosition == file.length() || (0 == initialPosition && ! initiallyExists) || + Utils.inShutdown, "initialPosition = " + initialPosition + + ", file len = " + file.length() + ", exists = " + file.exists()) + assert(lastValidPosition == initialPosition) } } - override def revertPartialWrites() { + override def revertPartialWritesAndClose() { if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + // opened, file still open + // Discard current writes. We do this by truncating the file to the last valid position. + closeAll(needFlush = true, needRevert = true) + validateBytesWritten() + assert(bytesWritten == 0, "bytesWritten = " + bytesWritten + + ", initial pos = " + initialPosition + ", last valid pos = " + lastValidPosition) + assert(initialPosition == file.length() || Utils.inShutdown, + "initialPosition = " + initialPosition + + ", file len = " + file.length() + ", exists = " + file.exists()) + } else if (cleanCloseAttempted) { + // Already opened and closed : truncate to last location (or delete + // if created in this instance) + assert(closed) + cleanCloseAttempted = false + + // truncate to initialPosition + // remove if not earlier existed + if (! initiallyExists) { + // best case effort so we dont care about return value + // of delete (it can fail if file was already deleted by cleaner threads for example) + file.delete() + } else if (file.exists()) { + DiskBlockObjectWriter.truncateIfExists(file, initialPosition) + } + // reset position. + lastValidPosition = initialPosition + + + assert(file.length() == initialPosition || Utils.inShutdown, + "initialPosition = " + initialPosition + + ", file len = " + file.length() + ", exists = " + file.exists()) + } else { + this.closed = true + assert(initialPosition == file.length() || (0 == initialPosition && ! initiallyExists) || + Utils.inShutdown, + "initialPosition = " + initialPosition + + ", file len = " + file.length() + ", exists = " + file.exists()) } } @@ -176,10 +406,17 @@ private[spark] class DiskBlockObjectWriter( if (!initialized) { open() } + // Not checking if closed on purpose ... introduce it ? No usecase for it right now. objOut.writeObject(value) } override def fileSegment(): FileSegment = { + assert(! 
initialized) + assert(null == fos) + assert(wasOpenedOnce || 0 == bytesWritten, + "wasOpenedOnce = " + wasOpenedOnce + ", initialPosition = " + initialPosition + + ", bytesWritten = " + bytesWritten + ", file len = " + file.length()) + new FileSegment(file, initialPosition, bytesWritten) } @@ -188,6 +425,39 @@ private[spark] class DiskBlockObjectWriter( // Only valid if called after commit() override def bytesWritten: Long = { - lastValidPosition - initialPosition + val retval = lastValidPosition - initialPosition + + assert(retval >= 0 || Utils.inShutdown(), + "exists = " + file.exists() + ", bytesWritten = " + retval + + ", lastValidPosition = " + lastValidPosition + ", initialPosition = " + initialPosition + + ", in shutdown = " + Utils.inShutdown()) + + // TODO: Comment this out when we are done validating : can be expensive due to file.length() + assert(file.length() >= lastValidPosition || Utils.inShutdown(), + "exists = " + file.exists() + ", file len = " + file.length() + + ", bytesWritten = " + retval + ", lastValidPosition = " + lastValidPosition + + ", initialPosition = " + initialPosition + ", in shutdown = " + Utils.inShutdown()) + + if (retval >= 0) retval else 0 } } + +object DiskBlockObjectWriter{ + + // Unfortunately, cant do it atomically ... + private def truncateIfExists(file: File, truncatePos: Long) { + var fos: FileOutputStream = null + try { + // There is no way to do this atomically iirc. + if (file.exists() && file.length() != truncatePos) { + fos = new FileOutputStream(file, true) + fos.getChannel.truncate(truncatePos) + } + } finally { + if (null != fos) { + fos.close() + } + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2e7ed7538e6e..01a82a767fb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} @@ -73,19 +73,22 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD val dirId = hash % localDirs.length val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + // prevent DCL: Other (expensive) option is to proactively create all directories ... 
// Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - newDir.mkdir() - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.isDirectory()) { + val created = newDir.mkdirs() + if (!created && !newDir.isDirectory) { + throw new IOException("Unable to create directory " + newDir.getAbsolutePath) + } } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -114,6 +117,8 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD while (getFile(blockId).exists()) { blockId = new TempBlockId(UUID.randomUUID()) } + // Note, since the file is not created, theoretically the while loop can return the same file + // Practically though, since we use UUID, this should not happen. (blockId, getFile(blockId)) } @@ -162,8 +167,9 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { + logInfo("shutting down DiskBlockManager") localDirs.foreach { localDir => - if (localDir.isDirectory() && localDir.exists()) { + if (localDir.exists() && localDir.isDirectory()) { try { if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 35910e552fe8..86d75553d6dd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.Logging import org.apache.spark.serializer.Serializer import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ @@ -74,7 +74,6 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. */ private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() @@ -85,6 +84,11 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { val completedMapTasks = new ConcurrentLinkedQueue[Int]() } + // Ensure the nextFileId is globally unique, It is ok if it wraps around after 'long' time. 
+ private object ShuffleState { + val nextFileId = new AtomicInteger(0) + } + type ShuffleId = Int private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] @@ -124,7 +128,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { if (consolidateShuffleFiles) { if (success) { val offsets = writers.map(_.fileSegment().offset) - fileGroup.recordMapOutput(mapId, offsets) + val lengths = writers.map(_.fileSegment().length) + fileGroup.recordMapOutput(mapId, offsets, lengths) } recycleFileGroup(fileGroup) } else { @@ -138,7 +143,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() + val fileId = ShuffleState.nextFileId.getAndIncrement() val files = Array.tabulate[File](numBuckets) { bucketId => val filename = physicalFileName(shuffleId, bucketId, fileId) blockManager.diskBlockManager.getFile(filename) @@ -236,31 +241,61 @@ object ShuffleBlockManager { new PrimitiveVector[Long]() } - def numBlocks = mapIdToIndex.size + /* + * This is required for shuffle consolidation to work. In particular when updates to file are + * happening while parallel requests to fetch block happens. + */ + private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } + + private var numBlocks = 0 def apply(bucketId: Int) = files(bucketId) - def recordMapOutput(mapId: Int, offsets: Array[Long]) { + def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { mapIdToIndex(mapId) = numBlocks for (i <- 0 until offsets.length) { + assert(blockOffsetsByReducer(i).size == numBlocks) + assert(offsets(i) >= 0) + assert(blockOffsetsByReducer(i).size <= 0 || + blockOffsetsByReducer(i)(blockOffsetsByReducer(i).size - 1) + + blockLengthsByReducer(i)(blockLengthsByReducer(i).size - 1) == offsets(i), + "Failed for " + i + ", blockOffsetsByReducer = " + blockOffsetsByReducer(i) + + ", blockLengthsByReducer = " + blockLengthsByReducer(i)) blockOffsetsByReducer(i) += offsets(i) + blockLengthsByReducer(i) += lengths(i) + assert(files(i).length() == lengths(i) + offsets(i) || Utils.inShutdown(), + "file = " + files(i).getAbsolutePath + ", offset = " + offsets(i) + + ", length = " + lengths(i) + ", file len = " + files(i).length()) } + numBlocks += 1 } /** Returns the FileSegment associated with the given map task, or None if no entry exists. 
*/ def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) val index = mapIdToIndex.getOrElse(mapId, -1) if (index >= 0) { + val file = files(reducerId) + val blockOffsets = blockOffsetsByReducer(reducerId) + val blockLengths = blockLengthsByReducer(reducerId) val offset = blockOffsets(index) - val length = - if (index + 1 < numBlocks) { - blockOffsets(index + 1) - offset - } else { - file.length() - offset - } + val length = blockLengths(index) + + assert(offset >= 0) assert(length >= 0) + assert(blockOffsets.size <= index + 1 || blockOffsets(index + 1) == offset + length, + "Failed for reducerId = " + reducerId + " index = " + index + ", offset = " + offset + + ", length = " + length + ", file exists = " + file.exists() + + ", blockOffsetsByReducer = " + blockOffsetsByReducer(reducerId) + + ", blockLengthsByReducer = " + blockLengthsByReducer(reducerId) + + ", file len = " + file.length() + ", file = " + file.getAbsolutePath) + assert(file.length() >= offset + length || (! file.exists() && Utils.inShutdown()), + "Failed for reducerId = " + reducerId + " index = " + index + ", offset = " + offset + + ", length = " + length + ", file exists = " + file.exists() + + ", blockOffsetsByReducer = " + blockOffsetsByReducer(reducerId) + + ", blockLengthsByReducer = " + blockLengthsByReducer(reducerId) + + ", file len = " + file.length() + ", file = " + file.getAbsolutePath) Some(new FileSegment(file, offset, length)) } else { None diff --git a/core/src/main/scala/org/apache/spark/util/Java7Util.scala b/core/src/main/scala/org/apache/spark/util/Java7Util.scala new file mode 100644 index 000000000000..84c3c5d7ac5e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Java7Util.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.File + +/** + * Java 7 (or higher) specific util methods. 
+ */ +object Java7Util { + def isSymlink(file: File) = java.nio.file.Files.isSymbolicLink(file.toPath) +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5784e974fbb6..074c96b1687d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -31,9 +31,9 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} +import org.apache.commons.lang3.{JavaVersion, SystemUtils} import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} @@ -264,7 +264,7 @@ private[spark] object Utils extends Logging { try { dir = new File(root, "spark-" + UUID.randomUUID.toString) if (dir.exists() || !dir.mkdirs()) { - dir = null + if (!dir.isDirectory) dir = null } } catch { case e: IOException => ; } } @@ -274,7 +274,6 @@ private[spark] object Utils extends Logging { // Add a shutdown hook to delete the temp dir when the JVM exits Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) } }) @@ -339,17 +338,18 @@ private[spark] object Utils extends Logging { logDebug("fetchFile with security enabled") val newuri = constructURIForAuthentication(uri, securityMgr) uc = newuri.toURL().openConnection() - uc.setAllowUserInteraction(false) } else { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 + uc.setAllowUserInteraction(false) uc.setConnectTimeout(timeout) uc.setReadTimeout(timeout) uc.connect() - val in = uc.getInputStream() + val len = uc.getContentLengthLong + val in = if (len < 0) uc.getInputStream else ByteStreams.limit(uc.getInputStream, len) val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { @@ -618,18 +618,22 @@ private[spark] object Utils extends Logging { * Check to see if file is a symbolic link. 
*/ def isSymlink(file: File): Boolean = { - if (file == null) throw new NullPointerException("File must not be null") - if (isWindows) return false - val fileInCanonicalDir = if (file.getParent() == null) { - file + if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_1_7)) { + Java7Util.isSymlink(file) } else { - new File(file.getParentFile().getCanonicalFile(), file.getName()) - } + if (file == null) throw new NullPointerException("File must not be null") + if (isWindows) return false + val fileInCanonicalDir = if (file.getParent() == null) { + file + } else { + new File(file.getParentFile().getCanonicalFile(), file.getName()) + } - if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) { - return false - } else { - return true + if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) { + return false + } else { + return true + } } } @@ -809,8 +813,8 @@ private[spark] object Utils extends Logging { */ def getCallSite: CallSite = { val trace = Thread.currentThread.getStackTrace() - .filterNot { ste:StackTraceElement => - // When running under some profilers, the current stack trace might contain some bogus + .filterNot { ste:StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus // frames. This is intended to ensure that we don't crash in these situations by // ignoring any frames that we can't examine. (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) @@ -934,6 +938,11 @@ private[spark] object Utils extends Logging { * * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing * an IllegalStateException. + * + * TODO: This will detect only if VM is shutting down, not when we are programmatically shutting + * down spark via stop()'s, like AppClient.markDead, etc. Unfortunately, the attempt (below) to + * fix this ran into issues with local mode and how test suites are run. + * So for now, some assertions and/or code paths which require latter to be detected will fail */ def inShutdown(): Boolean = { try { @@ -1024,7 +1033,7 @@ private[spark] object Utils extends Logging { def nonNegativeHash(obj: AnyRef): Int = { // Required ? 
- if (obj eq null) return 0 + if (null == obj) return 0 val hash = obj.hashCode // math.abs fails for Int.MinValue diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index be8f6529f7a1..4110459318f7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException} +import java.io.{BufferedInputStream, File, FileInputStream, EOFException, IOException, Serializable} import java.util.Comparator import scala.collection.BufferedIterator @@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator @@ -161,16 +161,22 @@ class ExternalAppendOnlyMap[K, V, C]( // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] + var totalBytesWritten = 0L // Flush the disk writer's contents to disk, and update relevant variables def flush() = { - writer.commit() - val bytesWritten = writer.bytesWritten + val w = writer + writer = null + w.commitAndClose() + val bytesWritten = w.bytesWritten batchSizes.append(bytesWritten) + totalBytesWritten += bytesWritten + assert(file.length() == totalBytesWritten) _diskBytesSpilled += bytesWritten objectsWritten = 0 } + var success = false try { val it = currentMap.destructiveSortedIterator(keyComparator) while (it.hasNext) { @@ -180,16 +186,25 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer.close() writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) } } if (objectsWritten > 0) { flush() + } else if (null != writer) { + val w = writer + writer = null + w.revertPartialWritesAndClose() } + success = true } finally { - // Partial failures cannot be tolerated; do not revert partial writes - writer.close() + if (success) { + assert(null == writer) + assert(file.length() == totalBytesWritten) + } else { + if (null != writer) writer.revertPartialWritesAndClose() + if (file.exists()) file.delete() + } } currentMap = new SizeTrackingAppendOnlyMap[K, C] @@ -353,26 +368,53 @@ class ExternalAppendOnlyMap[K, V, C]( */ private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) extends Iterator[(K, C)] { - private val fileStream = new FileInputStream(file) - private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + + assert(! batchSizes.isEmpty) + assert(! 
batchSizes.exists(_ <= 0)) + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) + assert(file.length() == batchOffsets(batchOffsets.length - 1)) + + private var batchIndex = 0 + private var fileStream: FileInputStream = null // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var batchStream = nextBatchStream() - private var compressedStream = blockManager.wrapForCompression(blockId, batchStream) - private var deserializeStream = ser.deserializeStream(compressedStream) + private var deserializeStream = nextBatchStream() private var nextItem: (K, C) = null private var objectsRead = 0 /** * Construct a stream that reads only from the next batch. */ - private def nextBatchStream(): InputStream = { - if (batchSizes.length > 0) { - ByteStreams.limit(bufferedStream, batchSizes.remove(0)) + private def nextBatchStream(): DeserializationStream = { + if (batchIndex + 1 < batchOffsets.length) { + assert(file.length() == batchOffsets(batchOffsets.length - 1)) + if (null != deserializeStream) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + val start = batchOffsets(batchIndex) + fileStream = new FileInputStream(file) + fileStream.getChannel.position(start) + assert(start == fileStream.getChannel.position()) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val strm = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + + val compressedStream = blockManager.wrapForCompression(blockId, strm) + val ser = serializer.newInstance() + ser.deserializeStream(compressedStream) } else { // No more batches left - bufferedStream + cleanup() + null } } @@ -387,10 +429,8 @@ class ExternalAppendOnlyMap[K, V, C]( val item = deserializeStream.readObject().asInstanceOf[(K, C)] objectsRead += 1 if (objectsRead == serializerBatchSize) { - batchStream = nextBatchStream() - compressedStream = blockManager.wrapForCompression(blockId, batchStream) - deserializeStream = ser.deserializeStream(compressedStream) objectsRead = 0 + deserializeStream = nextBatchStream() } item } catch { @@ -402,6 +442,7 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { + if (null == deserializeStream) return false nextItem = readNextItem() } nextItem != null @@ -418,7 +459,25 @@ class ExternalAppendOnlyMap[K, V, C]( // TODO: Ensure this gets called even if the iterator isn't drained. 
private def cleanup() { - deserializeStream.close() + batchIndex = batchOffsets.length + val dstrm = deserializeStream + val fstrm = fileStream + deserializeStream = null + fileStream = null + + if (null != dstrm) { + try { + dstrm.close() + } catch { + case ioEx: IOException => { + // best case attempt - atleast free the handles + if (null != fstrm) { + try { fstrm.close() } catch {case ioEx: IOException => } + } + throw ioEx + } + } + } file.delete() } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index aaa771404973..92a7956f4655 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -26,7 +26,10 @@ import com.google.common.io.Files import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.util.Utils +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.{AkkaUtils, Utils} +import akka.actor.Props +import org.apache.spark.scheduler.LiveListenerBus class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -121,6 +124,87 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before newFile.delete() } + private def checkSegments(segment1: FileSegment, segment2: FileSegment) { + assert(segment1.file.getCanonicalPath === segment2.file.getCanonicalPath) + assert(segment1.offset === segment2.offset) + assert(segment1.length === segment2.length) + } + + test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { + + val serializer = new JavaSerializer(testConf) + val confCopy = testConf.clone + // reset after EACH object write. This is to ensure that there are bytes appended after + // an object is written. So if the codepaths assume writeObject is end of data, this should + // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. + confCopy.set("spark.serializer.objectStreamReset", "1") + + val securityManager = new org.apache.spark.SecurityManager(confCopy) + // Do not use the shuffleBlockManager above ! + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy, + securityManager) + val master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), + confCopy) + val store = new BlockManager("", actorSystem, master , serializer, confCopy, + securityManager, null) + + try { + + val shuffleManager = store.shuffleBlockManager + + val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer) + for (writer <- shuffle1.writers) { + writer.write("test1") + writer.write("test2") + } + for (writer <- shuffle1.writers) { + writer.commitAndClose() + } + + val shuffle1Segment = shuffle1.writers(0).fileSegment() + shuffle1.releaseWriters(success = true) + + val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf)) + + for (writer <- shuffle2.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle2.writers) { + writer.commitAndClose() + } + shuffle2.releaseWriters(success = true) + + // Now comes the test : + // Write to shuffle 3; and close it, but before registering it, check if the file lengths for + // previous task (for shuffle1) is the same as 'segments'. 
Earlier, we were inferring length + // of block based on remaining data in file : which could mess things up when there is concurrent read + // and writes happening to the same shuffle group. + + val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf)) + for (writer <- shuffle3.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle3.writers) { + writer.commitAndClose() + } + // check before we register. + checkSegments(shuffle1Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 1, 0))) + shuffle3.releaseWriters(success = true) + checkSegments(shuffle1Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 1, 0))) + shuffleManager.removeShuffle(1) + } finally { + + if (store != null) { + store.stop() + } + actorSystem.shutdown() + actorSystem.awaitTermination() + } + } + def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) { val segment = diskBlockManager.getBlockLocation(blockId) assert(segment.file.getName === filename) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala new file mode 100644 index 000000000000..28c1132c9a91 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.FunSuite +import java.io.{IOException, FileOutputStream, OutputStream, File} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + +/** + * Test various code paths in DiskBlockObjectWriter + */ +class DiskBlockObjectWriterSuite extends FunSuite { + + private val conf = new SparkConf + private val BUFFER_SIZE = 32 * 1024 + + private def tempFile(): File = { + val file = File.createTempFile("temp_", "block") + // We dont want file to exist ! 
Just need a temp file name + file.delete() + file + } + + private def createWriter(file: File = tempFile()) : + (File, DiskBlockObjectWriter) = { + file.deleteOnExit() + + (file, new DiskBlockObjectWriter(BlockId("test_1"), file, + new JavaSerializer(conf), BUFFER_SIZE, (out: OutputStream) => out, true)) + } + + + test("write after close should throw IOException") { + val (file, bow) = createWriter() + bow.write("test") + bow.write("test1") + assert(file.exists() && file.isFile) + + bow.commitAndClose() + + intercept[IOException] { + bow.write("test2") + } + + file.delete() + } + + test("write after revert should throw IOException") { + val (file, bow) = createWriter() + bow.write("test") + bow.write("test1") + assert(file.exists() && file.isFile) + + bow.revertPartialWritesAndClose() + + intercept[IOException] { + bow.write("test2") + } + + file.delete() + } + + test("create even if directory does not exist") { + val dir = File.createTempFile("temp_", "dir") + dir.delete() + + val file = new File(dir, "temp.file") + file.deleteOnExit() + + val bow = new DiskBlockObjectWriter(BlockId("test_1"), file, new JavaSerializer(conf), + BUFFER_SIZE, (out: OutputStream) => out, true) + + bow.write("test") + assert(file.exists() && file.isFile) + bow.commitAndClose() + Utils.deleteRecursively(dir) + } + + test("revert of new file should delete it") { + val (file, bow) = createWriter() + bow.write("test") + bow.write("test1") + assert(file.exists() && file.isFile) + + bow.revertPartialWritesAndClose() + assert(! file.exists()) + // file.delete() + } + + test("revert of existing file should revert it to previous state") { + val (file, bow1) = createWriter() + + bow1.write("test") + bow1.write("test1") + assert(file.exists() && file.isFile) + + bow1.commitAndClose() + val length = file.length() + + // reopen same file. + val bow2 = createWriter(file)._2 + + bow2.write("test3") + bow2.write("test4") + + assert(file.exists() && file.isFile) + + bow2.revertPartialWritesAndClose() + assert(file.exists()) + assert(length == file.length()) + file.delete() + } + + test("revert of writer after close should delete if it did not exist earlier") { + val (file, bow) = createWriter(tempFile()) + + bow.write("test") + bow.write("test1") + assert(file.exists() && file.isFile) + + bow.commitAndClose() + val length = file.length() + + assert(file.exists() && file.isFile) + assert(length > 0) + + // Now revert the file, after it has been closed : should delete the file + // since it did not exist earlier. + bow.revertPartialWritesAndClose() + assert(! file.exists()) + file.delete() + } + + test("revert of writer after close should revert it to previous state") { + val (file, bow1) = createWriter() + + bow1.write("test") + bow1.write("test1") + assert(file.exists() && file.isFile) + + bow1.commitAndClose() + val length = file.length() + + // reopen same file. + val bow2 = createWriter(file)._2 + + bow2.write("test3") + bow2.write("test4") + + bow2.commitAndClose() + + assert(file.exists() && file.isFile) + assert(file.length() > length) + + // Now revert it : should get reverted back to previous state - after bow1 + bow2.revertPartialWritesAndClose() + assert(file.exists()) + assert(length == file.length()) + file.delete() + } + + test("If file changes before the open, throw exception instead of corrupting file") { + val (file, bow) = createWriter() + + // Now modify the file from under bow + val strm = new FileOutputStream(file, true) + strm.write(0) + strm.close() + + // try to write - must throw exception. 
+ intercept[IOException] { + bow.write("test") + } + bow.revertPartialWritesAndClose() + // must not have deleted the file : since we never opened it. + assert(file.exists()) + file.delete() + } + + test("If exception is thrown while close, must throw exception") { + + val file = tempFile() + file.deleteOnExit() + + var disallowWrite = false + class TestOutputStream(delegate: OutputStream) extends OutputStream { + override def write(b: Int) { + if (disallowWrite) throw new IOException("disallowed by config") + delegate.write(b) + } + + override def flush() { + if (disallowWrite) throw new IOException("disallowed by config") + delegate.flush() + } + + override def close() { + delegate.close() + } + } + var customCompressor = (out: OutputStream) => new TestOutputStream(out) + + val bow = new DiskBlockObjectWriter(BlockId("test_1"), file, + new JavaSerializer(conf), BUFFER_SIZE, customCompressor, true) + + bow.write("test1") + bow.write("test2") + + var afterWrite = false + // try to write - must throw exception at close + intercept[IOException] { + bow.write("test3") + afterWrite = true + disallowWrite = true + // The buffer should mean we dont need to write + // bow.write("test4") + bow.commitAndClose() + } + + assert(afterWrite) + bow.revertPartialWritesAndClose() + assert(! file.exists()) + } + + test("If exception is thrown while revert, must still revert successfully") { + + val file = tempFile() + file.deleteOnExit() + + var disallowWrite = false + class TestOutputStream(delegate: OutputStream) extends OutputStream { + override def write(b: Int) { + if (disallowWrite) throw new IOException("disallowed by config") + delegate.write(b) + } + + override def flush() { + if (disallowWrite) throw new IOException("disallowed by config") + delegate.flush() + } + + override def close() { + delegate.close() + } + } + var customCompressor = (out: OutputStream) => new TestOutputStream(out) + + val bow1 = new DiskBlockObjectWriter(BlockId("test_1"), file, + new JavaSerializer(conf), BUFFER_SIZE, customCompressor, true) + + + bow1.write("test1") + bow1.write("test2") + + bow1.commitAndClose() + val length = file.length() + + val bow2 = new DiskBlockObjectWriter(BlockId("test_1"), file, + new JavaSerializer(conf), BUFFER_SIZE, customCompressor, true) + + bow2.write("test3") + bow2.write("test4") + // must cause exception to be raised, but should be handled gracefully and reverted. + // Note, it depends on current impl which does a flush within code. 
+ disallowWrite = true + // The buffer should mean we dont need to write + // bow1.write("test4") + bow2.revertPartialWritesAndClose() + + assert(file.exists()) + assert(file.length() == length) + file.delete() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 428822949c08..369a96d5e63e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -30,8 +30,20 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { private def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i private def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + private def createSparkConf(loadDefaults: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + // So that we have suffix data written for every object + // This is to also test if we handle data written after object as part of + // spill properly (TC_RESET) + conf.set("spark.serializer.objectStreamReset", "0") + conf.set("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") + // Ensure that we actually have multiple blocks per spill + conf.set("spark.shuffle.spill.batchSize", "1") + conf + } + test("simple insert") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -57,7 +69,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("insert with collision") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -79,7 +91,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -124,7 +136,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("null keys and values") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -165,7 +177,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("simple aggregator") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) // reduceByKey @@ -180,7 +192,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("simple cogroup") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) @@ -198,10 +210,18 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") sc = new 
SparkContext("local-cluster[1,1,512]", "test", conf) + // Simple case - should spill ~16 times + val rddSimple = sc.parallelize(0 until 100000).map(i => (i, i)) + val resultSimple: Array[(Int, Int)] = rddSimple.reduceByKey(math.max).collect() + assert(resultSimple.length === 100000) + assert(resultSimple.map(_._1).toSet.size === 100000) + assert(resultSimple.map(_._1).min === 0) + assert(resultSimple.map(_._1).max === 99999) + // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) val resultA = rddA.reduceByKey(math.max).collect() @@ -250,7 +270,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -305,7 +325,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with many hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -330,7 +350,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -348,7 +368,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8e8c35615a71..1f490b76e0e9 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -62,9 +62,8 @@ object StoragePerfTester { writers(i % numOutputSplits).write(writeData) } writers.map {w => - w.commit() + w.commitAndClose() total.addAndGet(w.fileSegment().length) - w.close() } shuffle.releaseWriters(true)