diff --git a/docs/configuration.md b/docs/configuration.md
index 6aa7878fe614..0c7c4472be64 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -456,33 +456,6 @@ Apart from these, the following properties are also available, and may be useful
from JVM to Python worker for every task.
- <tr>
-   <td><code>spark.sql.repl.eagerEval.enabled</code></td>
-   <td>false</td>
-   <td>
-     Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation,
-     the Dataset will be run automatically. The HTML table generated by <code>_repr_html_</code>,
-     which notebooks like Jupyter call, reflects the queries the user has defined. For the plain Python
-     REPL, the output is shown like <code>dataframe.show()</code>
-     (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
-   </td>
- </tr>
- <tr>
-   <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
-   <td>20</td>
-   <td>
-     Default number of rows in the eager evaluation output HTML table generated by <code>_repr_html_</code>
-     or plain text. This only takes effect when <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
-   </td>
- </tr>
- <tr>
-   <td><code>spark.sql.repl.eagerEval.truncate</code></td>
-   <td>20</td>
-   <td>
-     Default truncation length for cells in the eager evaluation output HTML table generated by
-     <code>_repr_html_</code> or plain text. This only takes effect when
-     <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
-   </td>
- </tr>
 <tr>
   <td><code>spark.files</code></td>
   <td></td>
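These properties move out of docs/configuration.md but remain settable like any other SQL configuration. A minimal sketch, assuming a fresh PySpark session is being built (the builder calls below are standard SparkSession API; the app name is illustrative):

    from pyspark.sql import SparkSession

    # Build a session with eager evaluation enabled and bounded output.
    spark = (SparkSession.builder
             .appName("eager-eval-demo")                          # illustrative name
             .config("spark.sql.repl.eagerEval.enabled", "true")
             .config("spark.sql.repl.eagerEval.maxNumRows", "10")
             .config("spark.sql.repl.eagerEval.truncate", "20")
             .getOrCreate())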
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1e6a1acebb5c..cb3fe448b6fc 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -393,9 +393,8 @@ def _repr_html_(self):
self._support_repr_html = True
if self._eager_eval:
max_num_rows = max(self._max_num_rows, 0)
- vertical = False
sock_info = self._jdf.getRowsToPython(
- max_num_rows, self._truncate, vertical)
+ max_num_rows, self._truncate)
rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
head = rows[0]
row_data = rows[1:]
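For context, `getRowsToPython` returns the header row followed by the data rows, all as strings. A rough, illustrative sketch of turning such rows into the HTML table that `_repr_html_` returns (the real method also appends an "only showing top N rows" footer and does its own escaping; `rows_to_html` is a hypothetical helper, not PySpark API):

    import html

    def rows_to_html(head, row_data):
        # head: list of column names; row_data: list of rows, each a list of strings,
        # mirroring the Seq[Seq[String]] layout produced by getRows on the JVM side.
        parts = ["<table border='1'>"]
        parts.append("<tr>" + "".join("<th>%s</th>" % html.escape(c) for c in head) + "</tr>")
        for row in row_data:
            parts.append("<tr>" + "".join("<td>%s</td>" % html.escape(c) for c in row) + "</tr>")
        parts.append("</table>")
        return "\n".join(parts)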
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 35a0636e5cfc..8d738069adb3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3351,11 +3351,41 @@ def test_checking_csv_header(self):
finally:
shutil.rmtree(path)
- def test_repr_html(self):
+ def test_repr_behaviors(self):
import re
pattern = re.compile(r'^ *\|', re.MULTILINE)
df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
- self.assertEquals(None, df._repr_html_())
+
+ # test when eager evaluation is enabled and _repr_html_ will not be called
+ with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+ expected1 = """+-----+-----+
+ ||  key|value|
+ |+-----+-----+
+ ||    1|    1|
+ ||22222|22222|
+ |+-----+-----+
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
+ with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+ expected2 = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ ||  1|    1|
+ ||222|  222|
+ |+---+-----+
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
+ with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+ expected3 = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ ||  1|    1|
+ |+---+-----+
+ |only showing top 1 row
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
+
+ # test when eager evaluation is enabled and _repr_html_ will be called
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
expected1 = """
|| key | value |
@@ -3381,6 +3411,18 @@ def test_repr_html(self):
|"""
self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+ # test when eager evaluation is disabled and _repr_html_ will be called
+ with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
+ expected = "DataFrame[key: bigint, value: string]"
+ self.assertEquals(None, df._repr_html_())
+ self.assertEquals(expected, df.__repr__())
+ with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+ self.assertEquals(None, df._repr_html_())
+ self.assertEquals(expected, df.__repr__())
+ with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+ self.assertEquals(None, df._repr_html_())
+ self.assertEquals(expected, df.__repr__())
+
class HiveSparkSubmitTests(SparkSubmitTests):
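The behaviors asserted above can be reproduced interactively; a rough sketch, assuming an existing `spark` session (outputs abridged in the comments):

    spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
    df = spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
    df.__repr__()       # show()-style grid, as in expected1 above
    df._repr_html_()    # "<table border='1'>..." string that Jupyter renders

    spark.conf.set("spark.sql.repl.eagerEval.enabled", "false")
    df.__repr__()       # "DataFrame[key: bigint, value: string]"
    df._repr_html_()    # None, so notebooks fall back to the plain repr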
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 239c8266351a..e1752ff997b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1330,6 +1330,29 @@ object SQLConf {
"The size function returns null for null input if the flag is disabled.")
.booleanConf
.createWithDefault(true)
+
+ val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled")
+ .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " +
+ "displayed if and only if the REPL supports the eager evaluation. Currently, the " +
+ "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " +
+ "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " +
+ "the returned outputs are formatted like dataframe.show().")
+ .booleanConf
+ .createWithDefault(false)
+
+ val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows")
+ .doc("The max number of rows that are returned by eager evaluation. This only takes " +
+ "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " +
+ "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " +
+ "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).")
+ .intConf
+ .createWithDefault(20)
+
+ val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate")
+ .doc("The max number of characters for each cell that is returned by eager evaluation. " +
+ "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.")
+ .intConf
+ .createWithDefault(20)
}
/**
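The maxNumRows doc above describes clamping rather than failing on out-of-range values, and the same rule appears again in Dataset.getRowsToPython as `_numRows.max(0).min(Int.MaxValue - 1)`. A small Python sketch of that normalization, using a hypothetical helper name:

    def clamp_num_rows(requested, int_max=2**31 - 1):
        # Negative values fall back to 0; values above Int.MaxValue - 1 are capped.
        return max(0, min(requested, int_max - 1))

    assert clamp_num_rows(-5) == 0
    assert clamp_num_rows(20) == 20
    assert clamp_num_rows(2**31) == 2**31 - 2   # Int.MaxValue - 1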
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 57f1e173211a..2ec236fc75ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -236,12 +236,10 @@ class Dataset[T] private[sql](
* @param numRows Number of rows to return
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
- * @param vertical If set to true, the rows to return do not need truncate.
*/
private[sql] def getRows(
numRows: Int,
- truncate: Int,
- vertical: Boolean): Seq[Seq[String]] = {
+ truncate: Int): Seq[Seq[String]] = {
val newDf = toDF()
val castCols = newDf.logicalPlan.output.map { col =>
// Since binary types in top-level schema fields have a specific format to print,
@@ -289,7 +287,7 @@ class Dataset[T] private[sql](
vertical: Boolean = false): String = {
val numRows = _numRows.max(0).min(Int.MaxValue - 1)
// Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
- val tmpRows = getRows(numRows, truncate, vertical)
+ val tmpRows = getRows(numRows, truncate)
val hasMoreData = tmpRows.length - 1 > numRows
val rows = tmpRows.take(numRows + 1)
@@ -3226,11 +3224,10 @@ class Dataset[T] private[sql](
private[sql] def getRowsToPython(
_numRows: Int,
- truncate: Int,
- vertical: Boolean): Array[Any] = {
+ truncate: Int): Array[Any] = {
EvaluatePython.registerPicklers()
val numRows = _numRows.max(0).min(Int.MaxValue - 1)
- val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+ val rows = getRows(numRows, truncate).map(_.toArray).toArray
val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
rows.iterator.map(toJava))
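As the `truncate` parameter doc above says, cells longer than `truncate` characters are shortened; the getRows tests below show an ellipsis for larger truncate values (17, 20) but a plain prefix for truncate = 3. A Python sketch of that rule, using a hypothetical helper name (the exact cutoff of 4 for adding "..." is an assumption consistent with those tests):

    def truncate_cell(s, truncate):
        # truncate <= 0 disables truncation; otherwise long cells are cut,
        # keeping a trailing "..." only when truncate is large enough (>= 4 assumed).
        if truncate <= 0 or len(s) <= truncate:
            return s
        return s[:truncate] if truncate < 4 else s[:truncate - 3] + "..."

    assert truncate_cell("1" * 21, 20) == "1" * 17 + "..."
    assert truncate_cell("1" * 21, 3) == "111"
    assert truncate_cell("1" * 21, 0) == "1" * 21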
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1cc8cb3874c9..ea00d22bff00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.select($"*").show(1000)
}
+ test("getRows: truncate = [0, 20]") {
+ val longString = Array.fill(21)("1").mkString
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+ val expectedAnswerForFalse = Seq(
+ Seq("value"),
+ Seq("1"),
+ Seq("111111111111111111111"))
+ assert(df.getRows(10, 0) === expectedAnswerForFalse)
+ val expectedAnswerForTrue = Seq(
+ Seq("value"),
+ Seq("1"),
+ Seq("11111111111111111..."))
+ assert(df.getRows(10, 20) === expectedAnswerForTrue)
+ }
+
+ test("getRows: truncate = [3, 17]") {
+ val longString = Array.fill(21)("1").mkString
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+ val expectedAnswerForFalse = Seq(
+ Seq("value"),
+ Seq("1"),
+ Seq("111"))
+ assert(df.getRows(10, 3) === expectedAnswerForFalse)
+ val expectedAnswerForTrue = Seq(
+ Seq("value"),
+ Seq("1"),
+ Seq("11111111111111..."))
+ assert(df.getRows(10, 17) === expectedAnswerForTrue)
+ }
+
+ test("getRows: numRows = 0") {
+ val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1"))
+ assert(testData.select($"*").getRows(0, 20) === expectedAnswer)
+ }
+
+ test("getRows: array") {
+ val df = Seq(
+ (Array(1, 2, 3), Array(1, 2, 3)),
+ (Array(2, 3, 4), Array(2, 3, 4))
+ ).toDF()
+ val expectedAnswer = Seq(
+ Seq("_1", "_2"),
+ Seq("[1, 2, 3]", "[1, 2, 3]"),
+ Seq("[2, 3, 4]", "[2, 3, 4]"))
+ assert(df.getRows(10, 20) === expectedAnswer)
+ }
+
+ test("getRows: binary") {
+ val df = Seq(
+ ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)),
+ ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
+ ).toDF()
+ val expectedAnswer = Seq(
+ Seq("_1", "_2"),
+ Seq("[31 32]", "[41 42 43 2E]"),
+ Seq("[33 34]", "[31 32 33 34 36]"))
+ assert(df.getRows(10, 20) === expectedAnswer)
+ }
+
test("showString: truncate = [0, 20]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()