Skip to content

Commit 2e53f00

Browse files
committed
[#32] Improved output capturing and added tests
1 parent 65f8618 commit 2e53f00

File tree

10 files changed

+345
-61
lines changed

10 files changed

+345
-61
lines changed

build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ dependencies {
148148
compile 'khttp:khttp:1.0.0'
149149
compile 'org.zeromq:jeromq:0.3.5'
150150
compile 'com.beust:klaxon:5.2'
151+
compile 'com.github.ajalt:clikt:2.3.0'
151152
runtime 'org.slf4j:slf4j-simple:1.7.25'
152153
runtime "org.jetbrains.kotlin:jcabi-aether:1.0-dev-3"
153154
runtime "org.sonatype.aether:aether-api:1.13.1"

src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt

+16
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,22 @@ enum class JupyterSockets {
4040
iopub
4141
}
4242

43+
/**
 * Mutable settings that control how cell output is captured and streamed.
 *
 * @property captureOutput whether standard output should be captured at all
 * @property captureBufferTimeLimitMs elapsed time, in milliseconds, after which buffered output is sent
 * @property captureBufferMaxSize buffered size at which output is sent regardless of its content
 * @property cellOutputMaxSize total amount of output per cell after which capturing stops
 * @property captureNewlineBufferSize buffered size at which output is sent once a line break was seen
 */
data class OutputConfig(
    var captureOutput: Boolean = true,
    var captureBufferTimeLimitMs: Int = 100,
    var captureBufferMaxSize: Int = 1000,
    var cellOutputMaxSize: Int = 100000,
    var captureNewlineBufferSize: Int = 100
) {
    /**
     * Overwrites every setting of this instance with the corresponding value of [other],
     * so that holders of this instance observe the new settings.
     */
    fun assign(other: OutputConfig) {
        // Read all values first, then store them one by one.
        val (capture, timeLimitMs, bufferMax, cellMax, newlineBufferMax) = other
        captureOutput = capture
        captureBufferTimeLimitMs = timeLimitMs
        captureBufferMaxSize = bufferMax
        cellOutputMaxSize = cellMax
        captureNewlineBufferSize = newlineBufferMax
    }
}
58+
4359
data class KernelConfig(
4460
val ports: Array<Int>,
4561
val transport: String,

src/main/kotlin/org/jetbrains/kotlin/jupyter/connection.kt

+2-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ import com.beust.klaxon.Parser
55
import org.jetbrains.kotlin.com.intellij.openapi.Disposable
66
import org.jetbrains.kotlin.com.intellij.openapi.util.Disposer
77
import org.zeromq.ZMQ
8-
import java.io.ByteArrayOutputStream
98
import java.io.Closeable
10-
import java.io.PrintStream
119
import java.security.SignatureException
1210
import java.util.*
1311
import javax.crypto.Mac
@@ -20,6 +18,7 @@ class JupyterConnection(val config: KernelConfig): Closeable {
2018
init {
2119
val port = config.ports[socket.ordinal]
2220
bind("${config.transport}://*:$port")
21+
Thread.sleep(200)
2322
log.debug("[$name] listen: ${config.transport}://*:$port")
2423
}
2524

@@ -150,13 +149,9 @@ class HMAC(algo: String, key: String?) {
150149
operator fun invoke(vararg data: ByteArray): String? = invoke(data.asIterable())
151150
}
152151

153-
fun JupyterConnection.Socket.logWireMessage(msg: ByteArray) {
154-
log.debug("[$name] >in: ${String(msg)}")
155-
}
156-
157152
/**
 * Renders the array as a lowercase hexadecimal string, two digits per byte and no separators.
 */
fun ByteArray.toHexString(): String =
    joinToString(separator = "") { byte -> "%02x".format(byte) }
158153

159-
fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC): Unit {
154+
fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC) {
160155
msg.id.forEach { sendMore(it) }
161156
sendMore(DELIM)
162157
val signableMsg = listOf(msg.header, msg.parentHeader, msg.metadata, msg.content)

src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
package org.jetbrains.kotlin.jupyter
22

3+
import com.github.ajalt.clikt.core.CliktCommand
4+
import com.github.ajalt.clikt.parameters.options.default
5+
import com.github.ajalt.clikt.parameters.options.flag
6+
import com.github.ajalt.clikt.parameters.options.option
7+
import com.github.ajalt.clikt.parameters.types.int
38
import org.jetbrains.kotlin.jupyter.repl.spark.ClassWriter
49

510
/**
 * Line magics understood by the REPL.
 *
 * @property desc short human-readable description of the magic
 * @property argumentsUsage sample of the arguments the magic accepts, or `null` when it takes none
 * @property visibleInHelp whether the magic should be listed in help output
 */
enum class ReplLineMagics(val desc: String, val argumentsUsage: String? = null, val visibleInHelp: Boolean = true) {
    use("include supported libraries", "klaxon(5.0.1), lets-plot"),
    trackClasspath("log current classpath changes"),
    trackExecution("log code that is going to be executed in repl", visibleInHelp = false),
    dumpClassesForSpark("stores compiled repl classes in special folder for Spark integration", visibleInHelp = false),

    // The usage sample must only mention option names that the `output` magic parser declares
    // (--max-cell-size, --max-buffer, --max-buffer-newline, --max-time, --no-stdout, --reset-to-defaults);
    // a sample with other names cannot be pasted and executed as shown.
    output("setup output settings", "--max-cell-size 1000 --no-stdout --max-time 100 --max-buffer 400")
}
1117

1218
fun processMagics(repl: ReplForJupyter, code: String): String {
@@ -15,6 +21,35 @@ fun processMagics(repl: ReplForJupyter, code: String): String {
1521
var nextSearchIndex = 0
1622
var nextCopyIndex = 0
1723

24+
val outputParser = repl.outputConfig.let { conf ->
25+
object : CliktCommand() {
26+
val defaultConfig = OutputConfig()
27+
28+
val max: Int by option("--max-cell-size", help = "Maximum cell output").int().default(conf.cellOutputMaxSize)
29+
val maxBuffer: Int by option("--max-buffer", help = "Maximum buffer size").int().default(conf.captureBufferMaxSize)
30+
val maxBufferNewline: Int by option("--max-buffer-newline", help = "Maximum buffer size when newline got").int().default(conf.captureNewlineBufferSize)
31+
val maxTimeInterval: Int by option("--max-time", help = "Maximum time wait for output to accumulate").int().default(conf.captureBufferTimeLimitMs)
32+
val dontCaptureStdout: Boolean by option("--no-stdout", help = "Don't capture output").flag(default = !conf.captureOutput)
33+
val reset: Boolean by option("--reset-to-defaults", help = "Reset to defaults").flag()
34+
35+
override fun run() {
36+
if (reset) {
37+
conf.assign(defaultConfig)
38+
return
39+
}
40+
conf.assign(
41+
OutputConfig(
42+
!dontCaptureStdout,
43+
maxTimeInterval,
44+
maxBuffer,
45+
max,
46+
maxBufferNewline
47+
)
48+
)
49+
}
50+
}
51+
}
52+
1853
while (true) {
1954

2055
var magicStart: Int
@@ -55,6 +90,9 @@ fun processMagics(repl: ReplForJupyter, code: String): String {
5590
if (arg == null) throw ReplCompilerException("Need some arguments for 'use' command")
5691
repl.librariesCodeGenerator.processNewLibraries(repl, arg)
5792
}
93+
ReplLineMagics.output -> {
94+
outputParser.parse((arg ?: "").split(" "))
95+
}
5896
}
5997
nextCopyIndex = magicEnd
6098
nextSearchIndex = magicEnd

src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt

+83-36
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,25 @@ enum class ResponseState {
1515
Ok, Error
1616
}
1717

18+
/**
 * Output streams whose text can be reported to the Jupyter front end.
 */
enum class JupyterOutType {
    STDOUT, STDERR;

    /**
     * Name of the stream as it is placed into the `name` field of a `stream` message.
     *
     * The mapping is spelled out instead of lowercasing [name], so the protocol value never
     * depends on the JVM default locale; the exhaustive `when` forces a decision for any
     * constant added later.
     */
    fun optionName(): String = when (this) {
        STDOUT -> "stdout"
        STDERR -> "stderr"
    }
}
22+
1823
/**
 * Outcome of handling one request: the final state, an optional result value, any display
 * values, and optional standard output / standard error text.
 */
data class ResponseWithMessage(
    val state: ResponseState,
    val result: MimeTypedResult?,
    val displays: List<MimeTypedResult> = emptyList(),
    val stdOut: String? = null,
    val stdErr: String? = null
) {
    /** `true` when [stdOut] is present and contains at least one character. */
    val hasStdOut: Boolean = !stdOut.isNullOrEmpty()

    /** `true` when [stdErr] is present and contains at least one character. */
    val hasStdErr: Boolean = !stdErr.isNullOrEmpty()
}
2227

28+
/**
 * Publishes [text] on the connection's iopub socket as a `stream` message.
 *
 * @param msg message the reply is built from (passed to both `makeHeader` and `makeReplyMessage`)
 * @param stream which output stream the text belongs to; its [JupyterOutType.optionName] becomes the `name` field
 * @param text content placed into the `text` field
 */
fun JupyterConnection.Socket.sendOut(msg: Message, stream: JupyterOutType, text: String) {
    val iopub = connection.iopub
    val streamHeader = makeHeader("stream", msg)
    val streamContent = jsonObject(
        "name" to stream.optionName(),
        "text" to text
    )
    iopub.send(makeReplyMessage(msg, header = streamHeader, content = streamContent))
}
34+
2335
fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJupyter?, executionCount: AtomicLong) {
24-
val msgType = msg.header!!["msg_type"]
25-
when (msgType) {
36+
when (msg.header!!["msg_type"]) {
2637
"kernel_info_request" ->
2738
sendWrapped(msg, makeReplyMessage(msg, "kernel_info_reply",
2839
content = jsonObject(
@@ -67,23 +78,16 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
6778
val res: ResponseWithMessage = if (isCommand(code.toString())) {
6879
runCommand(code.toString(), repl)
6980
} else {
70-
connection.evalWithIO {
81+
connection.evalWithIO (repl?.outputConfig) {
7182
repl?.eval(code.toString(), count.toInt())
7283
}
7384
}
7485

75-
fun sendOut(stream: String, text: String) {
76-
connection.iopub.send(makeReplyMessage(msg, header = makeHeader("stream", msg),
77-
content = jsonObject(
78-
"name" to stream,
79-
"text" to text)))
80-
}
81-
8286
if (res.hasStdOut) {
83-
sendOut("stdout", res.stdOut!!)
87+
sendOut(msg, JupyterOutType.STDOUT, res.stdOut!!)
8488
}
8589
if (res.hasStdErr) {
86-
sendOut("stderr", res.stdErr!!)
90+
sendOut(msg, JupyterOutType.STDERR, res.stdErr!!)
8791
}
8892

8993
when (res.state) {
@@ -166,12 +170,50 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
166170
}
167171
}
168172

169-
class CapturingOutputStream(val stdout: PrintStream, val captureOutput: Boolean) : OutputStream() {
173+
/**
 * Output stream that forwards every byte to [stdout] and, while capturing is active, also
 * buffers it and hands the buffered text to [onCaptured] in chunks.
 *
 * Capturing is active when [captureOutput] is `true` and the number of bytes written so far
 * does not exceed `conf.cellOutputMaxSize`. A chunk is sent automatically when
 *  - a line break has been seen and the buffer holds at least `conf.captureNewlineBufferSize` bytes, or
 *  - the buffer holds at least `conf.captureBufferMaxSize` bytes, or
 *  - at least `conf.captureBufferTimeLimitMs` milliseconds passed since the last time-triggered send.
 *
 * Automatic sends happen only between complete UTF-8 characters, so a multi-byte character is
 * never split across two chunks. An explicit [flush] always sends everything that is buffered.
 *
 * No synchronization is performed by this class.
 *
 * @param stdout stream that receives every written byte unchanged
 * @param conf limits that decide when buffered output is sent
 * @param captureOutput whether written bytes are buffered for [onCaptured] at all
 * @param onCaptured receives each chunk of captured output, decoded as UTF-8
 */
class CapturingOutputStream(private val stdout: PrintStream,
                            private val conf: OutputConfig,
                            private val captureOutput: Boolean,
                            val onCaptured: (String) -> Unit) : OutputStream() {
    val capturedOutput = ByteArrayOutputStream()
    private var time = System.currentTimeMillis()

    // Long, so the counter cannot wrap to a negative value after 2^31 bytes and
    // accidentally re-enable capturing beyond cellOutputMaxSize.
    private var overallOutputSize = 0L
    private var newlineFound = false

    // Number of UTF-8 continuation bytes still expected for the character being written.
    private var pendingContinuationBytes = 0

    /** Updates [pendingContinuationBytes] for one byte (already reduced to 0..255). */
    private fun trackUtf8Boundary(byte: Int) {
        pendingContinuationBytes = when {
            (byte and 0xC0) == 0x80 -> maxOf(pendingContinuationBytes - 1, 0) // continuation byte
            (byte and 0xE0) == 0xC0 -> 1 // lead byte of a 2-byte character
            (byte and 0xF0) == 0xE0 -> 2 // lead byte of a 3-byte character
            (byte and 0xF8) == 0xF0 -> 3 // lead byte of a 4-byte character
            else -> 0                    // single-byte character (or invalid lead byte)
        }
    }

    /** Decides whether the buffer should be sent after [byte] (0..255) has been appended. */
    private fun shouldSend(byte: Int): Boolean {
        newlineFound = newlineFound || byte == 0x0A || byte == 0x0D
        if (newlineFound && capturedOutput.size() >= conf.captureNewlineBufferSize)
            return true
        if (capturedOutput.size() >= conf.captureBufferMaxSize)
            return true

        val currentTime = System.currentTimeMillis()
        if (currentTime - time >= conf.captureBufferTimeLimitMs) {
            time = currentTime
            return true
        }
        return false
    }

    override fun write(b: Int) {
        ++overallOutputSize
        stdout.write(b)

        if (captureOutput && overallOutputSize <= conf.cellOutputMaxSize) {
            capturedOutput.write(b)

            // Only the low 8 bits are written; the int may carry sign-extension bits.
            val byte = b and 0xFF
            trackUtf8Boundary(byte)

            // Never send in the middle of a multi-byte character: decoding a partial
            // sequence would produce replacement characters in the captured text.
            if (pendingContinuationBytes == 0 && shouldSend(byte)) {
                flush()
            }
        }
    }

    override fun flush() {
        newlineFound = false
        if (capturedOutput.size() > 0) {
            val str = capturedOutput.toString("UTF-8")
            capturedOutput.reset()
            onCaptured(str)
        }
    }
}
177219

@@ -182,17 +224,25 @@ fun Any.toMimeTypedResult(): MimeTypedResult? = when (this) {
182224
else -> textResult(this.toString())
183225
}
184226

185-
fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
227+
fun JupyterConnection.evalWithIO(maybeConfig: OutputConfig?, body: () -> EvalResult?): ResponseWithMessage {
186228
val out = System.out
187229
val err = System.err
230+
val config = maybeConfig ?: OutputConfig()
188231

189-
// TODO: make configuration option of whether to pipe back stdout and stderr
190-
// TODO: make a configuration option to limit the total stdout / stderr possibly returned (in case it goes wild...)
191-
val forkedOut = CapturingOutputStream(out, true)
192-
val forkedError = CapturingOutputStream(err, false)
232+
fun getCapturingStream(stream: PrintStream, outType: JupyterOutType, captureOutput: Boolean): CapturingOutputStream {
233+
return CapturingOutputStream(
234+
stream,
235+
config,
236+
captureOutput) { text ->
237+
this.iopub.sendOut(contextMessage!!, outType, text)
238+
}
239+
}
193240

194-
System.setOut(PrintStream(forkedOut, true, "UTF-8"))
195-
System.setErr(PrintStream(forkedError, true, "UTF-8"))
241+
val forkedOut = getCapturingStream(out, JupyterOutType.STDOUT, config.captureOutput)
242+
val forkedError = getCapturingStream(err, JupyterOutType.STDERR, false)
243+
244+
System.setOut(PrintStream(forkedOut, false, "UTF-8"))
245+
System.setErr(PrintStream(forkedError, false, "UTF-8"))
196246

197247
val `in` = System.`in`
198248
System.setIn(stdinIn)
@@ -202,26 +252,26 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
202252
if (exec == null) {
203253
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, "NO REPL!")
204254
} else {
205-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
206-
val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull()
255+
forkedOut.flush()
256+
forkedError.flush()
207257

208258
try {
209259
var result: MimeTypedResult? = null
210-
var displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }
260+
val displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }.toMutableList()
211261
if (exec.resultValue is DisplayResult) {
212262
val resultDisplay = exec.resultValue.value.toMimeTypedResult()
213263
if (resultDisplay != null)
214264
displays += resultDisplay
215265
} else result = exec.resultValue?.toMimeTypedResult()
216-
ResponseWithMessage(ResponseState.Ok, result, displays, stdOut, stdErr)
266+
ResponseWithMessage(ResponseState.Ok, result, displays, null, null)
217267
} catch (e: Exception) {
218-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut,
219-
joinLines(stdErr, "error: Unable to convert result to a string: ${e}"))
268+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null,
269+
"error: Unable to convert result to a string: $e")
220270
}
221271
}
222272
} catch (ex: ReplCompilerException) {
223-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
224-
val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull()
273+
forkedOut.flush()
274+
forkedError.flush()
225275

226276
// handle runtime vs. compile time and send back correct format of response, now we just send text
227277
/*
@@ -232,10 +282,10 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
232282
'traceback' : list(str), # traceback frames as strings
233283
}
234284
*/
235-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut,
236-
joinLines(stdErr, ex.errorResult.message))
285+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null,
286+
ex.errorResult.message)
237287
} catch (ex: ReplEvalRuntimeException) {
238-
val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull()
288+
forkedOut.flush()
239289

240290
// handle runtime vs. compile time and send back correct format of response, now we just send text
241291
/*
@@ -262,7 +312,7 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
262312
}
263313
}
264314
}
265-
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut, stdErr.toString())
315+
ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, stdErr.toString())
266316
}
267317
} finally {
268318
System.setIn(`in`)
@@ -271,7 +321,4 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage {
271321
}
272322
}
273323

274-
fun joinLines(vararg parts: String): String = parts.filter(String::isNotBlank).joinToString("\n")
275324
/**
 * Returns this string, or `null` when it is blank (empty or whitespace only).
 */
fun String.nullWhenEmpty(): String? = takeUnless { it.isBlank() }
276-
fun String?.emptyWhenNull(): String = if (this == null || this.isBlank()) "" else this
277-

src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class ReplCompilerException(val errorResult: ReplCompileResult.Error) : ReplExce
4040
class ReplForJupyter(val scriptClasspath: List<File> = emptyList(),
4141
val config: ResolverConfig? = null) {
4242

43+
val outputConfig = OutputConfig()
44+
4345
private val resolver = JupyterScriptDependenciesResolver(config)
4446

4547
private val renderers = config?.let {

0 commit comments

Comments
 (0)