Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import java.util.concurrent.atomic.AtomicInteger

import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}

import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
Expand Down Expand Up @@ -72,6 +73,8 @@ private[spark] class ConnectionManager(

// default to 30 second timeout waiting for authentication
private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val timeoutMs = conf.getInt("spark.core.connection.timeoutMs", 60000)
private val timer = new HashedWheelTimer(10, TimeUnit.MILLISECONDS)

private val handleMessageExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.handler.threads.min", 20),
Expand Down Expand Up @@ -172,7 +175,7 @@ private[spark] class ConnectionManager(
}
}
}
} )
})
}

private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
Expand All @@ -184,7 +187,7 @@ private[spark] class ConnectionManager(
readRunnableStarted.synchronized {
// So that we do not trigger more read events while processing this one.
// The read method will re-register when done.
if (conn.changeInterestForRead())conn.unregisterInterest()
if (conn.changeInterestForRead()) conn.unregisterInterest()
if (readRunnableStarted.contains(key)) {
return
}
Expand All @@ -205,7 +208,7 @@ private[spark] class ConnectionManager(
}
}
}
} )
})
}

private def triggerConnect(key: SelectionKey) {
Expand Down Expand Up @@ -233,7 +236,7 @@ private[spark] class ConnectionManager(
// not succeed : hence the loop to retry a few 'times'.
conn.finishConnect(true)
}
} )
})
}

// MUST be called within selector loop - else deadlock.
Expand Down Expand Up @@ -270,15 +273,15 @@ private[spark] class ConnectionManager(

def run() {
try {
while(!selectorThread.isInterrupted) {
while (!selectorThread.isInterrupted) {
while (!registerRequests.isEmpty) {
val conn: SendingConnection = registerRequests.dequeue()
addListeners(conn)
conn.connect()
addConnection(conn)
}

while(!keyInterestChangeRequests.isEmpty) {
while (!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue()

try {
Expand All @@ -300,7 +303,7 @@ private[spark] class ConnectionManager(
}

logTrace("Changed key for connection to [" +
connection.getRemoteConnectionManagerId() + "] changed from [" +
connection.getRemoteConnectionManagerId() + "] changed from [" +
intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
}
}
Expand Down Expand Up @@ -333,7 +336,7 @@ private[spark] class ConnectionManager(
while (allKeys.hasNext) {
val key = allKeys.next()
try {
if (! key.isValid) {
if (!key.isValid) {
logInfo("Key not valid ? " + key)
throw new CancelledKeyException()
}
Expand Down Expand Up @@ -536,7 +539,7 @@ private[spark] class ConnectionManager(
}
return
} else {
var replyToken : Array[Byte] = null
var replyToken: Array[Byte] = null
try {
replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
Expand Down Expand Up @@ -568,7 +571,7 @@ private[spark] class ConnectionManager(
connectionId: ConnectionId) {
if (!connection.isSaslComplete()) {
logDebug("saslContext not established")
var replyToken : Array[Byte] = null
var replyToken: Array[Byte] = null
try {
connection.synchronized {
if (connection.sparkSaslServer == null) {
Expand Down Expand Up @@ -614,7 +617,7 @@ private[spark] class ConnectionManager(
connectionsAwaitingSasl.get(connectionId) match {
case Some(waitingConn) => {
// Client - this must be in response to us doing Send
logDebug("Client handleAuth for id: " + waitingConn.connectionId)
logDebug("Client handleAuth for id: " + waitingConn.connectionId)
handleClientAuthentication(waitingConn, securityMsg, connectionId)
}
case None => {
Expand Down Expand Up @@ -777,7 +780,7 @@ private[spark] class ConnectionManager(
}
message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)
"connectionid: " + connection.connectionId)

if (authEnabled) {
// if we aren't authenticated yet lets block the senders until authentication completes
Expand Down Expand Up @@ -827,6 +830,30 @@ private[spark] class ConnectionManager(
selector.wakeup()
}

private def timeoutMessageStatus[T](msg: Message, future: Future[T]): Future[T] = {
val after = timeoutMs.milliseconds
val timerTask = new TimerTask {
def run(timeout: Timeout) {
messageStatuses.synchronized {
messageStatuses.get(msg.id) match {
case Some(status) => {
messageStatuses -= msg.id
logError("Time out while Sending [" + msg + "] to [" +
status.connectionManagerId + "]")
status.markDone(None)
}
case None => {
// TODO: ?
}
}
}
}
}
val timeout = timer.newTimeout(timerTask, after.toNanos, TimeUnit.NANOSECONDS)
future.onComplete { case result => timeout.cancel()}
future
}

/**
* Send a message and block until an acknowldgment is received or an error occurs.
* @param connectionManagerId the message's destination
Expand All @@ -853,14 +880,15 @@ private[spark] class ConnectionManager(
messageStatuses += ((message.id, status))
}
sendMessage(connectionManagerId, message)
promise.future
timeoutMessageStatus(message, promise.future)
}

def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
onReceiveCallback = callback
}

def stop() {
timer.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,29 @@ import scala.util.Try
*/
class ConnectionManagerSuite extends FunSuite {

test("receiver test with timeout") {
val conf = new SparkConf
conf.set("spark.core.connection.timeoutMs", "100")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
var receivedMessage = false
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
receivedMessage = true
Thread.sleep(1000)
val buffer = ByteBuffer.wrap("response".getBytes("utf-8"))
Some(Message.createBufferMessage(buffer, msg.id))
})

val future = manager.sendMessageReliably(manager.id, Message.createBufferMessage(
ByteBuffer.wrap("request".getBytes("utf-8"))))

intercept[IOException] {
Await.result(future, 3 second)
}
assert(receivedMessage == true)
manager.stop()
}

test("security default off") {
val conf = new SparkConf
val securityManager = new SecurityManager(conf)
Expand Down