Skip to content

Commit 0f89932

Browse files
committed
[#32] Improved output capturing and added tests
1 parent b9dadc8 commit 0f89932

File tree

3 files changed

+179
-33
lines changed

3 files changed

+179
-33
lines changed

src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt

+6-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ data class KernelConfig(
3636
val signatureKey: String,
3737
val pollingIntervalMillis: Long = 100,
3838
val scriptClasspath: List<File> = emptyList(),
39-
val resolverConfig: ResolverConfig?
39+
val resolverConfig: ResolverConfig?,
40+
41+
val captureOutput: Boolean = true,
42+
val captureBufferTimeLimitMs: Int = 100,
43+
val captureBufferMaxSize: Int = 1000,
44+
val cellOutputMaxSize: Int = 100000
4045
)
4146

4247
val protocolVersion = "5.3"

src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt

+83-32
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,25 @@ enum class ResponseState {
1515
Ok, Error
1616
}
1717

18+
enum class JupyterOutType {
19+
STDOUT, STDERR;
20+
fun optionName() = name.toLowerCase()
21+
}
22+
1823
data class ResponseWithMessage(val state: ResponseState, val result: MimeTypedResult?, val displays: List<MimeTypedResult> = emptyList(), val stdOut: String? = null, val stdErr: String? = null) {
1924
val hasStdOut: Boolean = stdOut != null && stdOut.isNotEmpty()
2025
val hasStdErr: Boolean = stdErr != null && stdErr.isNotEmpty()
2126
}
2227

28+
fun JupyterConnection.Socket.sendOut(msg:Message, stream: JupyterOutType, text: String) {
29+
connection.iopub.send(makeReplyMessage(msg, header = makeHeader("stream", msg),
30+
content = jsonObject(
31+
"name" to stream.optionName(),
32+
"text" to text)))
33+
}
34+
2335
fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJupyter?, executionCount: AtomicLong) {
24-
val msgType = msg.header!!["msg_type"]
25-
when (msgType) {
36+
when (msg.header!!["msg_type"]) {
2637
"kernel_info_request" ->
2738
sendWrapped(msg, makeReplyMessage(msg, "kernel_info_reply",
2839
content = jsonObject(
@@ -72,18 +83,11 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
7283
}
7384
}
7485

75-
fun sendOut(stream: String, text: String) {
76-
connection.iopub.send(makeReplyMessage(msg, header = makeHeader("stream", msg),
77-
content = jsonObject(
78-
"name" to stream,
79-
"text" to text)))
80-
}
81-
8286
if (res.hasStdOut) {
83-
sendOut("stdout", res.stdOut!!)
87+
sendOut(msg, JupyterOutType.STDOUT, res.stdOut!!)
8488
}
8589
if (res.hasStdErr) {
86-
sendOut("stderr", res.stdErr!!)
90+
sendOut(msg, JupyterOutType.STDERR, res.stdErr!!)
8791
}
8892

8993
when (res.state) {
@@ -166,13 +170,54 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
166170
}
167171
}
168172

169-
class CapturingOutputStream(val stdout: PrintStream, val captureOutput: Boolean) : OutputStream() {
173+
class CapturingOutputStream(private val stdout: PrintStream,
174+
private val maxOutputSize: Int,
175+
private val captureOutput: Boolean,
176+
private val maxBufferSize: Int,
177+
private val maxBufferLifeTimeMs: Int,
178+
val onCaptured: (String) -> Unit) : OutputStream() {
170179
val capturedOutput = ByteArrayOutputStream()
180+
private var time = System.currentTimeMillis()
181+
private var overallOutputSize = 0
182+
183+
private fun shouldSend(b: Int): Boolean {
184+
val c = b.toChar()
185+
if (c == '\n' || c == '\r')
186+
return true
187+
if (capturedOutput.size() >= maxBufferSize)
188+
return true
189+
190+
val currentTime = System.currentTimeMillis()
191+
if (currentTime - time >= maxBufferLifeTimeMs) {
192+
time = currentTime
193+
return true
194+
}
195+
return false
196+
}
171197

172198
override fun write(b: Int) {
199+
if (++overallOutputSize > maxOutputSize) {
200+
throw OutputLimitExceededException()
201+
}
202+
173203
stdout.write(b)
174-
if (captureOutput) capturedOutput.write(b)
204+
205+
if (captureOutput) {
206+
capturedOutput.write(b)
207+
if (shouldSend(b)) {
208+
flush()
209+
}
210+
}
211+
}
212+
213+
override fun flush() {
214+
if (capturedOutput.size() > 0) {
215+
onCaptured(capturedOutput.toString("UTF-8"))
216+
capturedOutput.reset()
217+
}
175218
}
219+
220+
class OutputLimitExceededException(message: String = "Cell output limit exceeded"): Exception(message)
176221
}
177222

178223
fun Any.toMimeTypedResult(): MimeTypedResult? = when (this) {
@@ -186,10 +231,19 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
186231
val out = System.out
187232
val err = System.err
188233

189-
// TODO: make configuration option of whether to pipe back stdout and stderr
190-
// TODO: make a configuration option to limit the total stdout / stderr possibly returned (in case it goes wild...)
191-
val forkedOut = CapturingOutputStream(out, true)
192-
val forkedError = CapturingOutputStream(err, false)
234+
fun getCapturingStream(stream: PrintStream, outType: JupyterOutType): CapturingOutputStream {
235+
return CapturingOutputStream(
236+
stream,
237+
config.cellOutputMaxSize,
238+
config.captureOutput,
239+
config.captureBufferMaxSize,
240+
config.captureBufferTimeLimitMs) { text ->
241+
this.iopub.sendOut(contextMessage!!, outType, text)
242+
}
243+
}
244+
245+
val forkedOut = getCapturingStream(out, JupyterOutType.STDOUT)
246+
val forkedError = getCapturingStream(err, JupyterOutType.STDERR)
193247

194248
System.setOut(PrintStream(forkedOut, true, "UTF-8"))
195249
System.setErr(PrintStream(forkedError, true, "UTF-8"))
@@ -202,26 +256,26 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
202256
if (exec == null) {
203257
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, "NO REPL!")
204258
} else {
205-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
206-
val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull()
259+
forkedOut.flush()
260+
forkedError.flush()
207261

208262
try {
209263
var result: MimeTypedResult? = null
210-
var displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }
264+
val displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }.toMutableList()
211265
if (exec.resultValue is DisplayResult) {
212266
val resultDisplay = exec.resultValue.value.toMimeTypedResult()
213267
if (resultDisplay != null)
214268
displays += resultDisplay
215269
} else result = exec.resultValue?.toMimeTypedResult()
216-
ResponseWithMessage(ResponseState.Ok, result, displays, stdOut, stdErr)
270+
ResponseWithMessage(ResponseState.Ok, result, displays, null, null)
217271
} catch (e: Exception) {
218-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut,
219-
joinLines(stdErr, "error: Unable to convert result to a string: ${e}"))
272+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null,
273+
"error: Unable to convert result to a string: $e")
220274
}
221275
}
222276
} catch (ex: ReplCompilerException) {
223-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
224-
val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull()
277+
forkedOut.flush()
278+
forkedError.flush()
225279

226280
// handle runtime vs. compile time and send back correct format of response, now we just send text
227281
/*
@@ -232,10 +286,10 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
232286
'traceback' : list(str), # traceback frames as strings
233287
}
234288
*/
235-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut,
236-
joinLines(stdErr, ex.errorResult.message))
289+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null,
290+
ex.errorResult.message)
237291
} catch (ex: ReplEvalRuntimeException) {
238-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
292+
forkedOut.flush()
239293

240294
// handle runtime vs. compile time and send back correct format of response, now we just send text
241295
/*
@@ -262,7 +316,7 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
262316
}
263317
}
264318
}
265-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut, stdErr.toString())
319+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, stdErr.toString())
266320
}
267321
} finally {
268322
System.setIn(`in`)
@@ -271,7 +325,4 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
271325
}
272326
}
273327

274-
fun joinLines(vararg parts: String): String = parts.filter(String::isNotBlank).joinToString("\n")
275328
fun String.nullWhenEmpty(): String? = if (this.isBlank()) null else this
276-
fun String?.emptyWhenNull(): String = if (this == null || this.isBlank()) "" else this
277-
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package org.jetbrains.kotlin.jupyter.test
2+
3+
import org.jetbrains.kotlin.jupyter.CapturingOutputStream
4+
import org.junit.Assert.assertArrayEquals
5+
import org.junit.Assert.assertEquals
6+
import org.junit.Test
7+
import java.io.OutputStream
8+
import java.io.PrintStream
9+
10+
class CapturingStreamTests {
11+
private val nullOStream = object: OutputStream() {
12+
override fun write(b: Int) {
13+
}
14+
}
15+
16+
private fun getStream(stdout: OutputStream = nullOStream,
17+
maxOutputSize: Int = 1000,
18+
captureOutput: Boolean = true,
19+
maxBufferSize: Int = 1000,
20+
maxBufferLifeTimeMs: Int = 1000,
21+
onCaptured: (String) -> Unit = {}): CapturingOutputStream {
22+
23+
val printStream = PrintStream(stdout, false, "UTF-8")
24+
return CapturingOutputStream(printStream, maxOutputSize, captureOutput,
25+
maxBufferSize, maxBufferLifeTimeMs, onCaptured)
26+
}
27+
28+
@Test
29+
fun testMaxOutputSizeOk() {
30+
val s = getStream(maxOutputSize = 6)
31+
s.write("kotlin".toByteArray())
32+
}
33+
34+
@Test(expected = CapturingOutputStream.OutputLimitExceededException::class)
35+
fun testMaxOutputSizeError() {
36+
val s = getStream(maxOutputSize = 3)
37+
s.write("java".toByteArray())
38+
}
39+
40+
@Test
41+
fun testOutputCapturingFlag() {
42+
val contents = "abc".toByteArray()
43+
44+
val s1 = getStream(captureOutput = false)
45+
s1.write(contents)
46+
assertEquals(0, s1.capturedOutput.size())
47+
48+
val s2 = getStream(captureOutput = true)
49+
s2.write(contents)
50+
assertArrayEquals(contents, s2.capturedOutput.toByteArray())
51+
}
52+
53+
@Test
54+
fun testMaxBufferSize() {
55+
val contents = "0123456789\nfortran".toByteArray()
56+
val expected = arrayOf("012", "345", "678", "9\n", "for", "tra", "n")
57+
58+
var i = 0
59+
val s = getStream(maxBufferSize = 3) {
60+
assertEquals(expected[i], it)
61+
++i
62+
}
63+
64+
s.write(contents)
65+
s.flush()
66+
67+
assertEquals(expected.size, i)
68+
}
69+
70+
@Test
71+
fun testMaxBufferLifeTime() {
72+
val strings = arrayOf("c ", "a", "ada ", "b", "scala ", "c")
73+
val expected = arrayOf("c a", "ada b", "scala c")
74+
75+
var i = 0
76+
val s = getStream(maxBufferLifeTimeMs = 1000) {
77+
assertEquals(expected[i], it)
78+
++i
79+
}
80+
81+
strings.forEach {
82+
Thread.sleep(600)
83+
s.write(it.toByteArray())
84+
}
85+
86+
s.flush()
87+
88+
assertEquals(expected.size, i)
89+
}
90+
}

0 commit comments

Comments
 (0)