diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index f198aa8564a54..e003a30c48d81 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -23,7 +23,7 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{Executors, LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} import java.util.{Timer, TimerTask} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} @@ -77,7 +77,8 @@ private[nio] class ConnectionManager( } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = Executors.newScheduledThreadPool(2, + Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) @@ -899,10 +900,12 @@ private[nio] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { + // to avoid reference to the whole message body + val messageId: Int = message.id + val timeoutTask = new Runnable { override def run(): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { + messageStatuses.remove(messageId).foreach ( s => { val e = new IOException("sendMessageReliably failed because ack " + s"was not received within $ackTimeout sec") if (!promise.tryFailure(e)) { @@ -913,8 +916,9 @@ private[nio] class ConnectionManager( } } + val timeoutTaskFuture = ackTimeoutMonitor.schedule(timeoutTask, ackTimeout, TimeUnit.SECONDS) val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() + timeoutTaskFuture.cancel(true) s match { case scala.util.Failure(e) => // Indicates a failure where we either never sent or never got ACK'd @@ -943,7 +947,6 @@ private[nio] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -953,7 +956,7 @@ private[nio] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.shutdownNow() selectorThread.interrupt() selectorThread.join() selector.close()