diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt index 1e5f9b106..63c583b7b 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt @@ -195,7 +195,15 @@ fun JupyterConnection.Socket.controlMessagesHandler(msg: Message, repl: ReplForJ is ShutdownRequest -> { repl?.evalOnShutdown() send(makeReplyMessage(msg, MessageType.SHUTDOWN_REPLY, content = msg.content)) - exitProcess(0) + // exitProcess would kill the entire process that embedded the kernel + // Instead the controlThread will be interrupted, + // which will then interrupt the mainThread and make kernelServer return + if (repl?.isEmbedded == true) { + log.info("Interrupting controlThread to trigger kernel shutdown") + throw InterruptedException() + } else { + exitProcess(0) + } } } } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt index e9626d02a..c27b107e8 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt @@ -132,6 +132,9 @@ interface ReplForJupyter { val notebook: NotebookImpl val fileExtension: String + + val isEmbedded: Boolean + get() = false } fun ReplForJupyter.execute(callback: ExecutionCallback): T { @@ -145,7 +148,7 @@ class ReplForJupyterImpl( override val resolverConfig: ResolverConfig? = null, override val runtimeProperties: ReplRuntimeProperties = defaultRuntimeProperties, private val scriptReceivers: List = emptyList(), - private val embedded: Boolean = false, + override val isEmbedded: Boolean = false, ) : ReplForJupyter, ReplOptions, BaseKernelHost, KotlinKernelHostProvider { constructor( @@ -262,7 +265,7 @@ class ReplForJupyterImpl( private val evaluatorConfiguration = ScriptEvaluationConfiguration { implicitReceivers.invoke(v = scriptReceivers) - if (!embedded) { + if (!isEmbedded) { jvm { val filteringClassLoader = FilteringClassLoader(ClassLoader.getSystemClassLoader()) { fqn -> listOf( diff --git a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/embeddingTest.kt b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/embeddingTest.kt index 8965d9d0a..d3ec6b57d 100644 --- a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/embeddingTest.kt +++ b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/embeddingTest.kt @@ -90,7 +90,7 @@ val testLibraryDefinition2 = LibraryDefinitionImpl( class EmbedReplTest : AbstractReplTest() { private val repl = run { val embeddedClasspath: List = System.getProperty("java.class.path").split(File.pathSeparator).map(::File) - ReplForJupyterImpl(resolutionInfoProvider, embeddedClasspath, embedded = true) + ReplForJupyterImpl(resolutionInfoProvider, embeddedClasspath, isEmbedded = true) } @Test