Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2344,29 +2344,24 @@ private[spark] class RedirectThread(
* the toString method.
*/
private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream {
var pos: Int = 0
var buffer = new Array[Int](sizeInBytes)
private var pos: Int = 0
private var isBufferFull = false
private val buffer = new Array[Byte](sizeInBytes)

def write(i: Int): Unit = {
buffer(pos) = i
def write(input: Int): Unit = {
buffer(pos) = input.toByte
pos = (pos + 1) % buffer.length
isBufferFull = isBufferFull || (pos == 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can use ||=, but, don't know if it's clearer

}

override def toString: String = {
val (end, start) = buffer.splitAt(pos)
val input = new java.io.InputStream {
val iterator = (start ++ end).iterator

def read(): Int = if (iterator.hasNext) iterator.next() else -1
}
val reader = new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8))
val stringBuilder = new StringBuilder
var line = reader.readLine()
while (line != null) {
stringBuilder.append(line)
stringBuilder.append("\n")
line = reader.readLine()
if (!isBufferFull) {
return new String(buffer, 0, pos, StandardCharsets.UTF_8)
}
stringBuilder.toString()

val nonCircularBuffer = new Array[Byte](sizeInBytes)
System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos)
System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos)
new String(nonCircularBuffer, StandardCharsets.UTF_8)
}
}
37 changes: 30 additions & 7 deletions core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, PrintStream}
import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
Expand Down Expand Up @@ -681,14 +681,37 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(!Utils.isInDirectory(nullFile, childFile3))
}

test("circular buffer") {
test("circular buffer: if nothing was written to the buffer, display nothing") {
val buffer = new CircularBuffer(4)
assert(buffer.toString === "")
}

test("circular buffer: if the buffer isn't full, print only the contents written") {
val buffer = new CircularBuffer(10)
val stream = new PrintStream(buffer, true, "UTF-8")
stream.print("test")
assert(buffer.toString === "test")
}

test("circular buffer: data written == size of the buffer") {
val buffer = new CircularBuffer(4)
val stream = new PrintStream(buffer, true, "UTF-8")

// fill the buffer to its exact size so that it just hits overflow
stream.print("test")
assert(buffer.toString === "test")

// add more data to the buffer
stream.print("12")
assert(buffer.toString === "st12")
}

test("circular buffer: multiple overflow") {
val buffer = new CircularBuffer(25)
val stream = new java.io.PrintStream(buffer, true, "UTF-8")
val stream = new PrintStream(buffer, true, "UTF-8")

// scalastyle:off println
stream.println("test circular test circular test circular test circular test circular")
// scalastyle:on println
assert(buffer.toString === "t circular test circular\n")
stream.print("test circular test circular test circular test circular test circular")
assert(buffer.toString === "st circular test circular")
}

test("nanSafeCompareDoubles") {
Expand Down