Skip to content

Commit bab0f92

Browse files
committed
[#32] Improved output capturing and added tests
1 parent 0815307 commit bab0f92

File tree

10 files changed

+345
-61
lines changed

10 files changed

+345
-61
lines changed

build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ dependencies {
166166
compile 'khttp:khttp:1.0.0'
167167
compile 'org.zeromq:jeromq:0.3.5'
168168
compile 'com.beust:klaxon:5.2'
169+
compile 'com.github.ajalt:clikt:2.3.0'
169170
runtime 'org.slf4j:slf4j-simple:1.7.25'
170171
runtime "org.jetbrains.kotlin:jcabi-aether:1.0-dev-3"
171172
runtime "org.sonatype.aether:aether-api:1.13.1"

src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt

+16
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,22 @@ enum class JupyterSockets {
4040
iopub
4141
}
4242

43+
/**
 * Mutable settings that control how cell output is captured.
 *
 * The limits are read by `CapturingOutputStream`; the instance is mutable so that
 * a shared configuration can be updated in place through [assign].
 */
data class OutputConfig(
    var captureOutput: Boolean = true,
    var captureBufferTimeLimitMs: Int = 100,
    var captureBufferMaxSize: Int = 1000,
    var cellOutputMaxSize: Int = 100000,
    var captureNewlineBufferSize: Int = 100
) {
    /**
     * Copies every setting from [other] into this instance, keeping this object's identity.
     */
    fun assign(other: OutputConfig) {
        // Destructuring follows the primary-constructor property order.
        val (capture, timeLimitMs, bufferMax, cellMax, newlineBuffer) = other
        captureOutput = capture
        captureBufferTimeLimitMs = timeLimitMs
        captureBufferMaxSize = bufferMax
        cellOutputMaxSize = cellMax
        captureNewlineBufferSize = newlineBuffer
    }
}
58+
4359
data class RuntimeKernelProperties(val map: Map<String, String>) {
4460
val version: String
4561
get() = map["version"] ?: "unspecified"

src/main/kotlin/org/jetbrains/kotlin/jupyter/connection.kt

+2-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ import com.beust.klaxon.Parser
55
import org.jetbrains.kotlin.com.intellij.openapi.Disposable
66
import org.jetbrains.kotlin.com.intellij.openapi.util.Disposer
77
import org.zeromq.ZMQ
8-
import java.io.ByteArrayOutputStream
98
import java.io.Closeable
10-
import java.io.PrintStream
119
import java.security.SignatureException
1210
import java.util.*
1311
import javax.crypto.Mac
@@ -20,6 +18,7 @@ class JupyterConnection(val config: KernelConfig): Closeable {
2018
// Binds this socket to the port configured for its socket kind, on all interfaces,
// then logs the listening address.
init {
2119
val port = config.ports[socket.ordinal]
2220
bind("${config.transport}://*:$port")
21+
// NOTE(review): fixed 200 ms pause after bind; the reason is not visible here.
// Presumably it gives peers time to connect before the first message is sent - confirm,
// and note that it blocks construction of every socket.
Thread.sleep(200)
2322
log.debug("[$name] listen: ${config.transport}://*:$port")
2423
}
2524

@@ -150,13 +149,9 @@ class HMAC(algo: String, key: String?) {
150149
/** Vararg convenience overload: forwards the given byte arrays, as an [Iterable], to another `invoke` overload. */
operator fun invoke(vararg data: ByteArray): String? = invoke(data.asIterable())
151150
}
152151

153-
fun JupyterConnection.Socket.logWireMessage(msg: ByteArray) {
154-
log.debug("[$name] >in: ${String(msg)}")
155-
}
156-
157152
/** Renders the array as lowercase hexadecimal text, two digits per byte, with no separator. */
fun ByteArray.toHexString(): String =
    joinToString(separator = "") { byte -> "%02x".format(byte) }
158153

159-
fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC): Unit {
154+
fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC) {
160155
msg.id.forEach { sendMore(it) }
161156
sendMore(DELIM)
162157
val signableMsg = listOf(msg.header, msg.parentHeader, msg.metadata, msg.content)

src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
package org.jetbrains.kotlin.jupyter
22

3+
import com.github.ajalt.clikt.core.CliktCommand
4+
import com.github.ajalt.clikt.parameters.options.default
5+
import com.github.ajalt.clikt.parameters.options.flag
6+
import com.github.ajalt.clikt.parameters.options.option
7+
import com.github.ajalt.clikt.parameters.types.int
38
import org.jetbrains.kotlin.jupyter.repl.spark.ClassWriter
49

510
/**
 * Line magics understood by the REPL.
 *
 * @property desc short human-readable description of the magic
 * @property argumentsUsage example of the arguments the magic accepts, or `null` if it takes none
 * @property visibleInHelp whether the magic should be listed in help output
 */
enum class ReplLineMagics(val desc: String, val argumentsUsage: String? = null, val visibleInHelp: Boolean = true) {
    use("include supported libraries", "klaxon(5.0.1), lets-plot"),
    trackClasspath("log current classpath changes"),
    trackExecution("log code that is going to be executed in repl", visibleInHelp = false),
    dumpClassesForSpark("stores compiled repl classes in special folder for Spark integration", visibleInHelp = false),

    // The usage example must use the option names actually declared by the output-settings parser
    // in processMagics (--max-cell-size, --max-buffer, --max-buffer-newline, --max-time, --no-stdout,
    // --reset-to-defaults); the previous example listed options that do not exist.
    output("setup output settings", "--max-cell-size 1000 --no-stdout --max-time 100 --max-buffer 400")
}
1117

1218
fun processMagics(repl: ReplForJupyter, code: String): String {
@@ -15,6 +21,35 @@ fun processMagics(repl: ReplForJupyter, code: String): String {
1521
var nextSearchIndex = 0
1622
var nextCopyIndex = 0
1723

24+
val outputParser = repl.outputConfig.let { conf ->
25+
object : CliktCommand() {
26+
val defaultConfig = OutputConfig()
27+
28+
val max: Int by option("--max-cell-size", help = "Maximum cell output").int().default(conf.cellOutputMaxSize)
29+
val maxBuffer: Int by option("--max-buffer", help = "Maximum buffer size").int().default(conf.captureBufferMaxSize)
30+
val maxBufferNewline: Int by option("--max-buffer-newline", help = "Maximum buffer size when newline got").int().default(conf.captureNewlineBufferSize)
31+
val maxTimeInterval: Int by option("--max-time", help = "Maximum time wait for output to accumulate").int().default(conf.captureBufferTimeLimitMs)
32+
val dontCaptureStdout: Boolean by option("--no-stdout", help = "Don't capture output").flag(default = !conf.captureOutput)
33+
val reset: Boolean by option("--reset-to-defaults", help = "Reset to defaults").flag()
34+
35+
override fun run() {
36+
if (reset) {
37+
conf.assign(defaultConfig)
38+
return
39+
}
40+
conf.assign(
41+
OutputConfig(
42+
!dontCaptureStdout,
43+
maxTimeInterval,
44+
maxBuffer,
45+
max,
46+
maxBufferNewline
47+
)
48+
)
49+
}
50+
}
51+
}
52+
1853
while (true) {
1954

2055
var magicStart: Int
@@ -55,6 +90,9 @@ fun processMagics(repl: ReplForJupyter, code: String): String {
5590
if (arg == null) throw ReplCompilerException("Need some arguments for 'use' command")
5691
repl.librariesCodeGenerator.processNewLibraries(repl, arg)
5792
}
93+
ReplLineMagics.output -> {
94+
outputParser.parse((arg ?: "").split(" "))
95+
}
5896
}
5997
nextCopyIndex = magicEnd
6098
nextSearchIndex = magicEnd

src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt

+83-36
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,25 @@ enum class ResponseState {
1515
Ok, Error
1616
}
1717

18+
/** Kind of output stream a piece of captured text belongs to. */
enum class JupyterOutType {
    STDOUT, STDERR;

    /** Lower-case stream name, used as the `name` field of a "stream" message. */
    fun optionName(): String = when (this) {
        STDOUT -> "stdout"
        STDERR -> "stderr"
    }
}
22+
1823
/**
 * Outcome of handling a request: the final [state], an optional [result], any additional
 * [displays], and optional text for the standard output and error streams.
 */
data class ResponseWithMessage(
    val state: ResponseState,
    val result: MimeTypedResult?,
    val displays: List<MimeTypedResult> = emptyList(),
    val stdOut: String? = null,
    val stdErr: String? = null
) {
    /** `true` when [stdOut] is present and contains at least one character. */
    val hasStdOut: Boolean = !stdOut.isNullOrEmpty()

    /** `true` when [stdErr] is present and contains at least one character. */
    val hasStdErr: Boolean = !stdErr.isNullOrEmpty()
}
2227

28+
/**
 * Sends [text] on the connection's iopub socket as a "stream" message replying to [msg].
 *
 * @param msg the message the output belongs to; used to build the header and the reply
 * @param stream which stream the text came from; its [JupyterOutType.optionName] becomes the `name` field
 * @param text the text to publish
 */
fun JupyterConnection.Socket.sendOut(msg: Message, stream: JupyterOutType, text: String) {
    val streamHeader = makeHeader("stream", msg)
    val streamContent = jsonObject(
        "name" to stream.optionName(),
        "text" to text
    )
    val reply = makeReplyMessage(msg, header = streamHeader, content = streamContent)
    connection.iopub.send(reply)
}
34+
2335
fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJupyter?, executionCount: AtomicLong) {
24-
val msgType = msg.header!!["msg_type"]
25-
when (msgType) {
36+
when (msg.header!!["msg_type"]) {
2637
"kernel_info_request" ->
2738
sendWrapped(msg, makeReplyMessage(msg, "kernel_info_reply",
2839
content = jsonObject(
@@ -70,23 +81,16 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
7081
val res: ResponseWithMessage = if (isCommand(code.toString())) {
7182
runCommand(code.toString(), repl)
7283
} else {
73-
connection.evalWithIO {
84+
connection.evalWithIO (repl?.outputConfig) {
7485
repl?.eval(code.toString(), count.toInt())
7586
}
7687
}
7788

78-
fun sendOut(stream: String, text: String) {
79-
connection.iopub.send(makeReplyMessage(msg, header = makeHeader("stream", msg),
80-
content = jsonObject(
81-
"name" to stream,
82-
"text" to text)))
83-
}
84-
8589
if (res.hasStdOut) {
86-
sendOut("stdout", res.stdOut!!)
90+
sendOut(msg, JupyterOutType.STDOUT, res.stdOut!!)
8791
}
8892
if (res.hasStdErr) {
89-
sendOut("stderr", res.stdErr!!)
93+
sendOut(msg, JupyterOutType.STDERR, res.stdErr!!)
9094
}
9195

9296
when (res.state) {
@@ -169,12 +173,50 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
169173
}
170174
}
171175

172-
class CapturingOutputStream(val stdout: PrintStream, val captureOutput: Boolean) : OutputStream() {
176+
/**
 * Output stream that forwards every byte to [stdout] and, when [captureOutput] is set, also buffers
 * the bytes and hands them to [onCaptured] as UTF-8 decoded text chunks.
 *
 * A chunk is emitted from [write] whenever one of the limits in [conf] is reached, and from [flush]
 * unconditionally. Once more than `conf.cellOutputMaxSize` bytes have been written in total, further
 * bytes are still forwarded to [stdout] but are no longer captured.
 *
 * @param stdout the stream every written byte is forwarded to
 * @param conf limits that decide when buffered output is emitted and when capturing stops
 * @param captureOutput whether written bytes are captured at all
 * @property onCaptured callback receiving each emitted text chunk
 */
class CapturingOutputStream(private val stdout: PrintStream,
                            private val conf: OutputConfig,
                            private val captureOutput: Boolean,
                            val onCaptured: (String) -> Unit) : OutputStream() {
    /** Bytes captured since the last emitted chunk. */
    val capturedOutput = ByteArrayOutputStream()

    // Timestamp of the last time-triggered send (or of construction).
    private var time = System.currentTimeMillis()

    // Long rather than Int: an Int counter wraps to a negative value after 2^31 written bytes,
    // which would make the size check below pass again and silently resume capturing.
    private var overallOutputSize = 0L

    // Whether a line break has been seen since the last emitted chunk.
    private var newlineFound = false

    /**
     * Decides whether the buffer should be emitted after byte [b] was appended:
     * a line break was seen and the newline buffer limit is reached, the general buffer
     * limit is reached, or the time limit since the last time-triggered send has elapsed.
     */
    private fun shouldSend(b: Int): Boolean {
        val c = b.toChar()
        newlineFound = newlineFound || c == '\n' || c == '\r'
        if (newlineFound && capturedOutput.size() >= conf.captureNewlineBufferSize)
            return true
        if (capturedOutput.size() >= conf.captureBufferMaxSize)
            return true

        val currentTime = System.currentTimeMillis()
        if (currentTime - time >= conf.captureBufferTimeLimitMs) {
            time = currentTime
            return true
        }
        return false
    }

    /**
     * Forwards [b] to [stdout] and, while capturing is enabled and the cell limit is not exceeded,
     * buffers it and emits a chunk when [shouldSend] says so.
     */
    override fun write(b: Int) {
        ++overallOutputSize
        stdout.write(b)

        if (captureOutput && overallOutputSize <= conf.cellOutputMaxSize) {
            capturedOutput.write(b)
            if (shouldSend(b)) {
                // Threshold-triggered sends can fall in the middle of a multi-byte character,
                // so only whole characters are emitted here.
                flushCompleteCharacters()
            }
        }
    }

    /**
     * Emits everything currently buffered (if anything) as one UTF-8 decoded chunk.
     * Unlike the threshold-triggered send in [write], nothing is held back.
     */
    override fun flush() {
        newlineFound = false
        if (capturedOutput.size() > 0) {
            val str = capturedOutput.toString("UTF-8")
            capturedOutput.reset()
            onCaptured(str)
        }
    }

    /**
     * Emits the buffered bytes up to the last complete UTF-8 character and keeps the bytes of a
     * trailing, still incomplete character in the buffer so they are decoded together with the
     * bytes that follow.
     */
    private fun flushCompleteCharacters() {
        newlineFound = false
        val bytes = capturedOutput.toByteArray()
        val completeLength = completeUtf8PrefixLength(bytes)
        if (completeLength == 0) return

        capturedOutput.reset()
        capturedOutput.write(bytes, completeLength, bytes.size - completeLength)
        onCaptured(String(bytes, 0, completeLength, Charsets.UTF_8))
    }

    /**
     * Returns the length of the longest prefix of [bytes] that does not end inside a UTF-8
     * multi-byte sequence. Malformed input is not held back: it is reported as complete.
     */
    private fun completeUtf8PrefixLength(bytes: ByteArray): Int {
        // Walk back over at most three continuation bytes (10xxxxxx) to the last lead byte.
        var index = bytes.size - 1
        var continuationBytes = 0
        while (index >= 0 && continuationBytes < 3 && (bytes[index].toInt() and 0xC0) == 0x80) {
            index--
            continuationBytes++
        }
        if (index < 0) return bytes.size

        val lead = bytes[index].toInt() and 0xFF
        val expectedLength = when {
            (lead and 0x80) == 0x00 -> 1
            (lead and 0xE0) == 0xC0 -> 2
            (lead and 0xF0) == 0xE0 -> 3
            (lead and 0xF8) == 0xF0 -> 4
            else -> 1 // not a valid lead byte; let the decoder substitute it
        }
        val availableLength = bytes.size - index
        return if (availableLength < expectedLength) index else bytes.size
    }
}
180222

@@ -185,17 +227,25 @@ fun Any.toMimeTypedResult(): MimeTypedResult? = when (this) {
185227
else -> textResult(this.toString())
186228
}
187229

188-
fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
230+
fun JupyterConnection.evalWithIO(maybeConfig: OutputConfig?, body: () -> EvalResult?): ResponseWithMessage {
189231
val out = System.out
190232
val err = System.err
233+
val config = maybeConfig ?: OutputConfig()
191234

192-
// TODO: make configuration option of whether to pipe back stdout and stderr
193-
// TODO: make a configuration option to limit the total stdout / stderr possibly returned (in case it goes wild...)
194-
val forkedOut = CapturingOutputStream(out, true)
195-
val forkedError = CapturingOutputStream(err, false)
235+
fun getCapturingStream(stream: PrintStream, outType: JupyterOutType, captureOutput: Boolean): CapturingOutputStream {
236+
return CapturingOutputStream(
237+
stream,
238+
config,
239+
captureOutput) { text ->
240+
this.iopub.sendOut(contextMessage!!, outType, text)
241+
}
242+
}
196243

197-
System.setOut(PrintStream(forkedOut, true, "UTF-8"))
198-
System.setErr(PrintStream(forkedError, true, "UTF-8"))
244+
val forkedOut = getCapturingStream(out, JupyterOutType.STDOUT, config.captureOutput)
245+
val forkedError = getCapturingStream(err, JupyterOutType.STDERR, false)
246+
247+
System.setOut(PrintStream(forkedOut, false, "UTF-8"))
248+
System.setErr(PrintStream(forkedError, false, "UTF-8"))
199249

200250
val `in` = System.`in`
201251
System.setIn(stdinIn)
@@ -205,26 +255,26 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
205255
if (exec == null) {
206256
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, "NO REPL!")
207257
} else {
208-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
209-
val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull()
258+
forkedOut.flush()
259+
forkedError.flush()
210260

211261
try {
212262
var result: MimeTypedResult? = null
213-
var displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }
263+
val displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }.toMutableList()
214264
if (exec.resultValue is DisplayResult) {
215265
val resultDisplay = exec.resultValue.value.toMimeTypedResult()
216266
if (resultDisplay != null)
217267
displays += resultDisplay
218268
} else result = exec.resultValue?.toMimeTypedResult()
219-
ResponseWithMessage(ResponseState.Ok, result, displays, stdOut, stdErr)
269+
ResponseWithMessage(ResponseState.Ok, result, displays, null, null)
220270
} catch (e: Exception) {
221-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut,
222-
joinLines(stdErr, "error: Unable to convert result to a string: ${e}"))
271+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null,
272+
"error: Unable to convert result to a string: $e")
223273
}
224274
}
225275
} catch (ex: ReplCompilerException) {
226-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
227-
val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull()
276+
forkedOut.flush()
277+
forkedError.flush()
228278

229279
// handle runtime vs. compile time and send back correct format of response, now we just send text
230280
/*
@@ -235,10 +285,10 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
235285
'traceback' : list(str), # traceback frames as strings
236286
}
237287
*/
238-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut,
239-
joinLines(stdErr, ex.errorResult.message))
288+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null,
289+
ex.errorResult.message)
240290
} catch (ex: ReplEvalRuntimeException) {
241-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
291+
forkedOut.flush()
242292

243293
// handle runtime vs. compile time and send back correct format of response, now we just send text
244294
/*
@@ -265,7 +315,7 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
265315
}
266316
}
267317
}
268-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut, stdErr.toString())
318+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, stdErr.toString())
269319
}
270320
} finally {
271321
System.setIn(`in`)
@@ -274,7 +324,4 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
274324
}
275325
}
276326

277-
fun joinLines(vararg parts: String): String = parts.filter(String::isNotBlank).joinToString("\n")
278327
/** Returns `null` when the string is blank (empty or whitespace only), otherwise the string itself. */
fun String.nullWhenEmpty(): String? = takeUnless { it.isBlank() }
279-
fun String?.emptyWhenNull(): String = if (this == null || this.isBlank()) "" else this
280-

src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class ReplCompilerException(val errorResult: ReplCompileResult.Error) : ReplExce
4040
class ReplForJupyter(val scriptClasspath: List<File> = emptyList(),
4141
val config: ResolverConfig? = null) {
4242

43+
val outputConfig = OutputConfig()
44+
4345
private val resolver = JupyterScriptDependenciesResolver(config)
4446

4547
private val renderers = config?.let {

0 commit comments

Comments
 (0)