From e707fb45adcf78e1322e8fa496c25fe9695d4e91 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 7 Aug 2014 11:05:04 +0800 Subject: [PATCH] BasicBlockFetchIterator#next can wait forever --- .../spark/network/ConnectionManager.scala | 54 ++++++++++++++----- .../network/ConnectionManagerSuite.scala | 23 ++++++++ 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 95f96b8463a01..a003897b6d337 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -26,6 +26,7 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -72,6 +73,8 @@ private[spark] class ConnectionManager( // default to 30 second timeout waiting for authentication private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val timeoutMs = conf.getInt("spark.core.connection.timeoutMs", 60000) + private val timer = new HashedWheelTimer(10, TimeUnit.MILLISECONDS) private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -172,7 +175,7 @@ private[spark] class ConnectionManager( } } } - } ) + }) } private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() @@ -184,7 +187,7 @@ private[spark] class ConnectionManager( readRunnableStarted.synchronized { // So that we do not trigger more read events while processing this one. // The read method will re-register when done. - if (conn.changeInterestForRead())conn.unregisterInterest() + if (conn.changeInterestForRead()) conn.unregisterInterest() if (readRunnableStarted.contains(key)) { return } @@ -205,7 +208,7 @@ private[spark] class ConnectionManager( } } } - } ) + }) } private def triggerConnect(key: SelectionKey) { @@ -233,7 +236,7 @@ private[spark] class ConnectionManager( // not succeed : hence the loop to retry a few 'times'. conn.finishConnect(true) } - } ) + }) } // MUST be called within selector loop - else deadlock. @@ -270,7 +273,7 @@ private[spark] class ConnectionManager( def run() { try { - while(!selectorThread.isInterrupted) { + while (!selectorThread.isInterrupted) { while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) @@ -278,7 +281,7 @@ private[spark] class ConnectionManager( addConnection(conn) } - while(!keyInterestChangeRequests.isEmpty) { + while (!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue() try { @@ -300,7 +303,7 @@ private[spark] class ConnectionManager( } logTrace("Changed key for connection to [" + - connection.getRemoteConnectionManagerId() + "] changed from [" + + connection.getRemoteConnectionManagerId() + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") } } @@ -333,7 +336,7 @@ private[spark] class ConnectionManager( while (allKeys.hasNext) { val key = allKeys.next() try { - if (! key.isValid) { + if (!key.isValid) { logInfo("Key not valid ? " + key) throw new CancelledKeyException() } @@ -536,7 +539,7 @@ private[spark] class ConnectionManager( } return } else { - var replyToken : Array[Byte] = null + var replyToken: Array[Byte] = null try { replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) if (waitingConn.isSaslComplete()) { @@ -568,7 +571,7 @@ private[spark] class ConnectionManager( connectionId: ConnectionId) { if (!connection.isSaslComplete()) { logDebug("saslContext not established") - var replyToken : Array[Byte] = null + var replyToken: Array[Byte] = null try { connection.synchronized { if (connection.sparkSaslServer == null) { @@ -614,7 +617,7 @@ private[spark] class ConnectionManager( connectionsAwaitingSasl.get(connectionId) match { case Some(waitingConn) => { // Client - this must be in response to us doing Send - logDebug("Client handleAuth for id: " + waitingConn.connectionId) + logDebug("Client handleAuth for id: " + waitingConn.connectionId) handleClientAuthentication(waitingConn, securityMsg, connectionId) } case None => { @@ -777,7 +780,7 @@ private[spark] class ConnectionManager( } message.senderAddress = id.toSocketAddress() logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + - "connectionid: " + connection.connectionId) + "connectionid: " + connection.connectionId) if (authEnabled) { // if we aren't authenticated yet lets block the senders until authentication completes @@ -827,6 +830,30 @@ private[spark] class ConnectionManager( selector.wakeup() } + private def timeoutMessageStatus[T](msg: Message, future: Future[T]): Future[T] = { + val after = timeoutMs.milliseconds + val timerTask = new TimerTask { + def run(timeout: Timeout) { + messageStatuses.synchronized { + messageStatuses.get(msg.id) match { + case Some(status) => { + messageStatuses -= msg.id + logError("Time out while Sending [" + msg + "] to [" + + status.connectionManagerId + "]") + status.markDone(None) + } + case None => { + // TODO: ? + } + } + } + } + } + val timeout = timer.newTimeout(timerTask, after.toNanos, TimeUnit.NANOSECONDS) + future.onComplete { case result => timeout.cancel()} + future + } + /** * Send a message and block until an acknowldgment is received or an error occurs. * @param connectionManagerId the message's destination @@ -853,7 +880,7 @@ private[spark] class ConnectionManager( messageStatuses += ((message.id, status)) } sendMessage(connectionManagerId, message) - promise.future + timeoutMessageStatus(message, promise.future) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { @@ -861,6 +888,7 @@ private[spark] class ConnectionManager( } def stop() { + timer.stop() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 846537df003df..1d12ce5f82911 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -33,6 +33,29 @@ import scala.util.Try */ class ConnectionManagerSuite extends FunSuite { + test("receiver test with timeout") { + val conf = new SparkConf + conf.set("spark.core.connection.timeoutMs", "100") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var receivedMessage = false + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + receivedMessage = true + Thread.sleep(1000) + val buffer = ByteBuffer.wrap("response".getBytes("utf-8")) + Some(Message.createBufferMessage(buffer, msg.id)) + }) + + val future = manager.sendMessageReliably(manager.id, Message.createBufferMessage( + ByteBuffer.wrap("request".getBytes("utf-8")))) + + intercept[IOException] { + Await.result(future, 3 second) + } + assert(receivedMessage == true) + manager.stop() + } + test("security default off") { val conf = new SparkConf val securityManager = new SecurityManager(conf)