Skip to content

Commit 154cdbc

Browse files
committed
Add on-shutdown codes for library descriptors
Fixes #87
1 parent 9e6cbfd commit 154cdbc

File tree

7 files changed

+80
-10
lines changed

7 files changed

+80
-10
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ Library descriptor is a `<libName>.json` file with the following fields:
196196
- `imports`: a list of default imports for library
197197
- `init`: a list of code snippets to be executed when library is included
198198
- `initCell`: a list of code snippets to be executed before execution of any cell
199+
- `shutdown`: a list of code snippets to be executed on kernel shutdown. Any cleanup code goes here
199200
- `renderers`: a list of type converters for special rendering of particular types
200201

201202
*All fields are optional

src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt

+4-1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ open class LibraryDefinition(
129129
val imports: List<String>,
130130
val repositories: List<String>,
131131
val init: List<String>,
132+
val shutdown: List<String>,
132133
val renderers: List<TypeHandler>,
133134
val converters: List<TypeHandler>,
134135
val annotations: List<TypeHandler>
@@ -140,10 +141,11 @@ class LibraryDescriptor(dependencies: List<String>,
140141
imports: List<String>,
141142
repositories: List<String>,
142143
init: List<String>,
144+
shutdown: List<String>,
143145
renderers: List<TypeHandler>,
144146
converters: List<TypeHandler>,
145147
annotations: List<TypeHandler>,
146-
val link: String?) : LibraryDefinition(dependencies, initCell, imports, repositories, init, renderers, converters, annotations)
148+
val link: String?) : LibraryDefinition(dependencies, initCell, imports, repositories, init, shutdown, renderers, converters, annotations)
147149

148150
data class ResolverConfig(val repositories: List<RepositoryCoordinates>,
149151
val libraries: Deferred<Map<String, LibraryDescriptor>>)
@@ -339,6 +341,7 @@ fun parserLibraryDescriptors(libJsons: Map<String, JsonObject>): Map<String, Lib
339341
imports = it.value.array<String>("imports")?.toList().orEmpty(),
340342
repositories = it.value.array<String>("repositories")?.toList().orEmpty(),
341343
init = it.value.array<String>("init")?.toList().orEmpty(),
344+
shutdown = it.value.array<String>("shutdown")?.toList().orEmpty(),
342345
initCell = it.value.array<String>("initCell")?.toList().orEmpty(),
343346
renderers = it.value.obj("renderers")?.map {
344347
TypeHandler(it.key, it.value.toString())

src/main/kotlin/org/jetbrains/kotlin/jupyter/libraries.kt

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class LibrariesProcessor(private val libraries: Deferred<Map<String, LibraryDesc
3838
repositories = library.repositories.map { replaceVariables(it, mapping) },
3939
imports = library.imports.map { replaceVariables(it, mapping) },
4040
init = library.init.map { replaceVariables(it, mapping) },
41+
shutdown = library.shutdown.map { replaceVariables(it, mapping) },
4142
initCell = library.initCell.map { replaceVariables(it, mapping) },
4243
renderers = library.renderers.map { TypeHandler(it.className, replaceVariables(it.code, mapping)) },
4344
converters = library.converters.map { TypeHandler(it.className, replaceVariables(it.code, mapping)) },

src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
9999
send(makeReplyMessage(msg, "interrupt_reply", content = msg.content))
100100
}
101101
"shutdown_request" -> {
102+
repl?.evalOnShutdown()
102103
send(makeReplyMessage(msg, "shutdown_reply", content = msg.content))
103104
exitProcess(0)
104105
}

src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt

+31-2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ typealias Code = String
6363
interface ReplForJupyter {
6464
fun eval(code: Code, displayHandler: ((Any) -> Unit)? = null, jupyterId: Int = -1): EvalResult
6565

66+
fun evalOnShutdown(): List<EvalResult>
67+
6668
fun checkComplete(code: Code): CheckResult
6769

6870
suspend fun complete(code: String, cursor: Int, callback: (CompletionResult) -> Unit)
@@ -110,6 +112,7 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
110112
private val typeRenderers = mutableMapOf<String, String>()
111113

112114
private val initCellCodes = mutableListOf<String>()
115+
private val shutdownCodes = mutableListOf<String>()
113116

114117
private fun renderResult(value: Any?, resultField: Pair<String, KotlinType>?): Any? {
115118
if (value == null || resultField == null) return null
@@ -119,13 +122,20 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
119122
return renderResult(result.value, result.resultField)
120123
}
121124

122-
data class PreprocessingResult(val code: Code, val initCodes: List<Code>, val initCellCodes: List<Code>, val typeRenderers: List<TypeHandler>)
125+
data class PreprocessingResult(
126+
val code: Code,
127+
val initCodes: List<Code>,
128+
val shutdownCodes: List<Code>,
129+
val initCellCodes: List<Code>,
130+
val typeRenderers: List<TypeHandler>,
131+
)
123132

124133
fun preprocessCode(code: String): PreprocessingResult {
125134

126135
val processedMagics = magics.processMagics(code)
127136

128137
val initCodes = mutableListOf<Code>()
138+
val shutdownCodes = mutableListOf<Code>()
129139
val initCellCodes = mutableListOf<Code>()
130140
val typeRenderers = mutableListOf<TypeHandler>()
131141
val typeConverters = mutableListOf<TypeHandler>()
@@ -141,13 +151,16 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
141151
typeRenderers.addAll(libraryDefinition.renderers)
142152
typeConverters.addAll(libraryDefinition.converters)
143153
annotations.addAll(libraryDefinition.annotations)
154+
initCellCodes.addAll(libraryDefinition.initCell)
155+
shutdownCodes.addAll(libraryDefinition.shutdown)
144156
libraryDefinition.init.forEach {
145157

146158
// Library init code may contain other magics, so we process them recursively
147159
val preprocessed = preprocessCode(it)
148160
initCodes.addAll(preprocessed.initCodes)
149161
typeRenderers.addAll(preprocessed.typeRenderers)
150162
initCellCodes.addAll(preprocessed.initCellCodes)
163+
shutdownCodes.addAll(preprocessed.shutdownCodes)
151164
if (preprocessed.code.isNotBlank())
152165
initCodes.add(preprocessed.code)
153166
}
@@ -159,7 +172,7 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
159172
initCodes.add(declarations)
160173
}
161174

162-
return PreprocessingResult(processedMagics.code, initCodes, initCellCodes, typeRenderers)
175+
return PreprocessingResult(processedMagics.code, initCodes, shutdownCodes, initCellCodes, typeRenderers)
163176
}
164177

165178
private val ctx = KotlinContext()
@@ -327,6 +340,7 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
327340

328341
private fun registerNewLibraries(p: PreprocessingResult) {
329342
p.initCellCodes.filter { !initCellCodes.contains(it) }.let(initCellCodes::addAll)
343+
p.shutdownCodes.filter { !shutdownCodes.contains(it) }.let(shutdownCodes::addAll)
330344
typeRenderers.putAll(p.typeRenderers.map { it.className to it.code })
331345
}
332346

@@ -395,6 +409,10 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
395409
}
396410
}
397411

412+
override fun evalOnShutdown(): List<EvalResult> {
413+
return shutdownCodes.map(::evalWithReturn)
414+
}
415+
398416
private fun updateOutputList(jupyterId: Int, result: Any?) {
399417
if (jupyterId >= 0) {
400418
while (ReplOutputs.count() <= jupyterId) ReplOutputs.add(null)
@@ -457,6 +475,17 @@ class ReplForJupyterImpl(private val scriptClasspath: List<File> = emptyList(),
457475
processAnnotations(lastReplLine())
458476
}
459477

478+
// Result of this function is considered to be used for testing/debug purposes
479+
private fun evalWithReturn(code: String): EvalResult {
480+
val result = try {
481+
doEval(code)
482+
} catch (e: Exception) {
483+
InternalEvalResult(null, null)
484+
}
485+
processAnnotations(lastReplLine())
486+
return EvalResult(result.value)
487+
}
488+
460489
private data class InternalEvalResult(val value: Any?, val resultField: Pair<String, KotlinType>?)
461490

462491
private interface LockQueueArgs <T> {

src/test/kotlin/org/jetbrains/kotlin/jupyter/test/replTests.kt

+32-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package org.jetbrains.kotlin.jupyter.test
22

3-
import com.beust.klaxon.JsonObject
4-
import com.beust.klaxon.Parser
53
import jupyter.kotlin.JavaRuntime
64
import jupyter.kotlin.MimeTypedResult
75
import jupyter.kotlin.receivers.ConstReceiver
@@ -11,10 +9,8 @@ import org.jetbrains.kotlin.jupyter.ReplCompilerException
119
import org.jetbrains.kotlin.jupyter.ReplEvalRuntimeException
1210
import org.jetbrains.kotlin.jupyter.ReplForJupyterImpl
1311
import org.jetbrains.kotlin.jupyter.ResolverConfig
14-
import org.jetbrains.kotlin.jupyter.asAsync
1512
import org.jetbrains.kotlin.jupyter.defaultRepositories
1613
import org.jetbrains.kotlin.jupyter.generateDiagnostic
17-
import org.jetbrains.kotlin.jupyter.parserLibraryDescriptors
1814
import org.jetbrains.kotlin.jupyter.repl.completion.CompletionResult
1915
import org.jetbrains.kotlin.jupyter.repl.completion.ListErrorsResult
2016
import org.jetbrains.kotlin.jupyter.withPath
@@ -363,11 +359,10 @@ class ReplTest : AbstractReplTest() {
363359
]
364360
}
365361
""".trimIndent()
366-
val parser = Parser.default()
367362

368-
val libJsons = arrayOf(lib1, lib2, lib3).map { it.first to parser.parse(StringBuilder(it.second)) as JsonObject }.toMap()
363+
val libs = listOf(lib1, lib2, lib3).toLibrariesAsync()
369364

370-
val replWithResolver = ReplForJupyterImpl(classpath, ResolverConfig(defaultRepositories, parserLibraryDescriptors(libJsons).asAsync()))
365+
val replWithResolver = ReplForJupyterImpl(classpath, ResolverConfig(defaultRepositories, libs))
371366
val res = replWithResolver.preprocessCode("%use mylib(1.0), another")
372367
assertEquals("", res.code)
373368
val inits = arrayOf(
@@ -397,6 +392,36 @@ class ReplTest : AbstractReplTest() {
397392
}
398393
}
399394

395+
@Test
396+
fun testLibraryOnShutdown() {
397+
val lib1 = "mylib" to """
398+
{
399+
"shutdown": [
400+
"14 * 3",
401+
"throw RuntimeException()",
402+
"21 + 22"
403+
]
404+
}""".trimIndent()
405+
406+
val lib2 = "mylib2" to """
407+
{
408+
"shutdown": [
409+
"100"
410+
]
411+
}""".trimIndent()
412+
413+
val libs = listOf(lib1, lib2).toLibrariesAsync()
414+
val replWithResolver = ReplForJupyterImpl(classpath, ResolverConfig(defaultRepositories, libs))
415+
replWithResolver.eval("%use mylib, mylib2")
416+
val results = replWithResolver.evalOnShutdown()
417+
418+
assertEquals(4, results.size)
419+
assertEquals(42, results[0].resultValue)
420+
assertNull(results[1].resultValue)
421+
assertEquals(43, results[2].resultValue)
422+
assertEquals(100, results[3].resultValue)
423+
}
424+
400425
@Test
401426
fun testJavaRuntimeUtils() {
402427
val result = repl.eval("JavaRuntimeUtils.version")

src/test/kotlin/org/jetbrains/kotlin/jupyter/test/testUtil.kt

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package org.jetbrains.kotlin.jupyter.test
22

3+
import com.beust.klaxon.JsonObject
4+
import com.beust.klaxon.Parser
35
import jupyter.kotlin.DependsOn
6+
import kotlinx.coroutines.Deferred
7+
import org.jetbrains.kotlin.jupyter.LibraryDescriptor
48
import org.jetbrains.kotlin.jupyter.ResolverConfig
59
import org.jetbrains.kotlin.jupyter.asAsync
610
import org.jetbrains.kotlin.jupyter.defaultRepositories
@@ -18,3 +22,9 @@ val classpath = scriptCompilationClasspathFromContext(
1822

1923
val testResolverConfig = ResolverConfig(defaultRepositories,
2024
parserLibraryDescriptors(readLibraries().toMap()).asAsync())
25+
26+
fun Collection<Pair<String, String>>.toLibrariesAsync(): Deferred<Map<String, LibraryDescriptor>> {
27+
val parser = Parser.default()
28+
val libJsons = map { it.first to parser.parse(StringBuilder(it.second)) as JsonObject }.toMap()
29+
return parserLibraryDescriptors(libJsons).asAsync()
30+
}

0 commit comments

Comments
 (0)