Skip to content

Commit cb358f6

Browse files
committed
Add thread-based interruption implementation
1 parent 8279b00 commit cb358f6

File tree

2 files changed

+74
-21
lines changed

2 files changed

+74
-21
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt

+45
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ import kotlinx.serialization.json.JsonObject
99
import kotlinx.serialization.json.decodeFromJsonElement
1010
import kotlinx.serialization.json.encodeToJsonElement
1111
import kotlinx.serialization.json.jsonObject
12+
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
1213
import org.zeromq.SocketType
1314
import org.zeromq.ZMQ
1415
import java.io.Closeable
1516
import java.io.IOException
1617
import java.security.SignatureException
1718
import javax.crypto.Mac
1819
import javax.crypto.spec.SecretKeySpec
20+
import kotlin.concurrent.thread
1921
import kotlin.math.min
2022

2123
class JupyterConnection(val config: KernelConfig) : Closeable {
@@ -144,6 +146,49 @@ class JupyterConnection(val config: KernelConfig) : Closeable {
144146

145147
var contextMessage: Message? = null
146148

149+
private val currentExecutions = HashSet<Thread>()
150+
151+
data class ConnectionExecutionResult<T>(
152+
val result: T?,
153+
val throwable: Throwable?,
154+
val isInterrupted: Boolean,
155+
)
156+
157+
fun <T> runExecution(body: () -> T): ConnectionExecutionResult<T> {
158+
var execRes: T? = null
159+
var execException: Throwable? = null
160+
val execThread = thread {
161+
try {
162+
execRes = body()
163+
} catch (e: Throwable) {
164+
execException = e
165+
}
166+
}
167+
currentExecutions.add(execThread)
168+
execThread.join()
169+
currentExecutions.remove(execThread)
170+
171+
val exception = execException
172+
val isInterrupted = exception is ThreadDeath ||
173+
(exception is ReplException && exception.cause is ThreadDeath)
174+
return ConnectionExecutionResult(execRes, exception, isInterrupted)
175+
}
176+
177+
/**
178+
* We cannot use [Thread.interrupt] here because we have no way
179+
* to control the code user executes. [Thread.interrupt] will do nothing for
180+
* the simple calculation (like `while (true) 1`). Consider replacing with
181+
* something more smart in the future.
182+
*/
183+
fun interruptExecution() {
184+
@Suppress("deprecation")
185+
while (currentExecutions.isNotEmpty()) {
186+
val execution = currentExecutions.firstOrNull()
187+
execution?.stop()
188+
currentExecutions.remove(execution)
189+
}
190+
}
191+
147192
override fun close() {
148193
heartbeat.close()
149194
shell.close()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

+29-21
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ class ErrorResponseWithMessage(
192192
fun JupyterConnection.Socket.controlMessagesHandler(msg: Message, repl: ReplForJupyter?) {
193193
when (msg.content) {
194194
is InterruptRequest -> {
195-
log.warn("Interruption is not yet supported!")
195+
connection.interruptExecution()
196196
send(makeReplyMessage(msg, MessageType.INTERRUPT_REPLY, content = msg.content))
197197
}
198198
is ShutdownRequest -> {
@@ -439,6 +439,12 @@ fun JupyterConnection.evalWithIO(repl: ReplForJupyter, srcMessage: Message, body
439439
val forkedError = getCapturingStream(err, JupyterOutType.STDERR, false)
440440
val userError = getCapturingStream(null, JupyterOutType.STDERR, true)
441441

442+
fun flushStreams() {
443+
forkedOut.flush()
444+
forkedError.flush()
445+
userError.flush()
446+
}
447+
442448
val printForkedOut = PrintStream(forkedOut, false, "UTF-8")
443449
val printForkedErr = PrintStream(forkedError, false, "UTF-8")
444450
val printUserError = PrintStream(userError, false, "UTF-8")
@@ -453,26 +459,30 @@ fun JupyterConnection.evalWithIO(repl: ReplForJupyter, srcMessage: Message, body
453459
System.setIn(if (allowStdIn) stdinIn else DisabledStdinInputStream)
454460
try {
455461
return try {
456-
val exec = body()
457-
if (exec == null) {
458-
AbortResponseWithMessage("NO REPL!")
459-
} else {
460-
forkedOut.flush()
461-
forkedError.flush()
462-
userError.flush()
463-
464-
try {
465-
val result = exec.resultValue?.toDisplayResult(repl.notebook)
466-
OkResponseWithMessage(result, exec.metadata)
467-
} catch (e: Exception) {
468-
AbortResponseWithMessage("error: Unable to convert result to a string: $e")
462+
val (exec, execException, executionInterrupted) = runExecution(body)
463+
when {
464+
executionInterrupted -> {
465+
flushStreams()
466+
AbortResponseWithMessage("The execution was interrupted")
467+
}
468+
execException != null -> {
469+
throw execException
470+
}
471+
exec == null -> {
472+
AbortResponseWithMessage("NO REPL!")
473+
}
474+
else -> {
475+
flushStreams()
476+
try {
477+
val result = exec.resultValue?.toDisplayResult(repl.notebook)
478+
OkResponseWithMessage(result, exec.metadata)
479+
} catch (e: Exception) {
480+
AbortResponseWithMessage("error: Unable to convert result to a string: $e")
481+
}
469482
}
470483
}
471484
} catch (ex: ReplException) {
472-
forkedOut.flush()
473-
forkedError.flush()
474-
userError.flush()
475-
485+
flushStreams()
476486
ErrorResponseWithMessage(
477487
ex.render(),
478488
ex.javaClass.canonicalName,
@@ -482,9 +492,7 @@ fun JupyterConnection.evalWithIO(repl: ReplForJupyter, srcMessage: Message, body
482492
)
483493
}
484494
} finally {
485-
forkedOut.close()
486-
forkedError.close()
487-
userError.close()
495+
flushStreams()
488496
System.setIn(`in`)
489497
System.setErr(err)
490498
System.setOut(out)

0 commit comments

Comments
 (0)