diff --git a/docs/configuration.md b/docs/configuration.md
index 64af0e98a82f..5588c372d3e4 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -456,6 +456,33 @@ Apart from these, the following properties are also available, and may be useful
from JVM to Python worker for every task.
+
+<tr>
+  <td><code>spark.sql.repl.eagerEval.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enables eager evaluation or not. If true and the REPL you are using supports eager evaluation,
+    the Dataset is evaluated automatically: notebooks such as Jupyter call <code>_repr_html_</code>
+    and render the result as an HTML table, while the plain Python REPL shows the output like
+    <code>dataframe.show()</code>
+    (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
+  <td>20</td>
+  <td>
+    The maximum number of rows shown in the eager evaluation output (the HTML table generated by
+    <code>_repr_html_</code> or plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.truncate</code></td>
+  <td>20</td>
+  <td>
+    The truncate length for cells in the eager evaluation output (the HTML table generated by
+    <code>_repr_html_</code> or plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
<td><code>spark.files</code></td>
<td></td>
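
For context while reviewing the doc change, here is a minimal sketch of how these three options are exercised from PySpark. The config keys come from this patch; the app name and sample data are illustrative assumptions only.

```python
from pyspark.sql import SparkSession

# Illustrative sketch only: enable eager evaluation so a DataFrame renders
# itself without an explicit .show() in REPLs that support it (e.g. Jupyter).
spark = (SparkSession.builder
         .appName("eager-eval-demo")  # hypothetical app name
         .config("spark.sql.repl.eagerEval.enabled", "true")
         .config("spark.sql.repl.eagerEval.maxNumRows", 5)
         .config("spark.sql.repl.eagerEval.truncate", 20)
         .getOrCreate())

df = spark.createDataFrame([(1, "a"), (2, "b")], ("id", "value"))
df  # Jupyter calls _repr_html_ -> HTML table; a plain REPL calls __repr__ -> show()-like text
```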
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 808235ab2544..1e6a1acebb5c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx):
self.is_cached = False
self._schema = None # initialized lazily
self._lazy_rdd = None
+ # Check whether _repr_html_ is supported or not; we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ when eager evaluation is enabled.
+ self._support_repr_html = False
@property
@since(1.3)
@@ -351,8 +354,68 @@ def show(self, n=20, truncate=True, vertical=False):
else:
print(self._jdf.showString(n, int(truncate), vertical))
+ @property
+ def _eager_eval(self):
+ """Returns true if the eager evaluation enabled.
+ """
+ return self.sql_ctx.getConf(
+ "spark.sql.repl.eagerEval.enabled", "false").lower() == "true"
+
+ @property
+ def _max_num_rows(self):
+ """Returns the max row number for eager evaluation.
+ """
+ return int(self.sql_ctx.getConf(
+ "spark.sql.repl.eagerEval.maxNumRows", "20"))
+
+ @property
+ def _truncate(self):
+ """Returns the truncate length for eager evaluation.
+ """
+ return int(self.sql_ctx.getConf(
+ "spark.sql.repl.eagerEval.truncate", "20"))
+
def __repr__(self):
- return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+ if not self._support_repr_html and self._eager_eval:
+ vertical = False
+ return self._jdf.showString(
+ self._max_num_rows, self._truncate, vertical)
+ else:
+ return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
+ def _repr_html_(self):
+ """Returns a dataframe with html code when you enabled eager evaluation
+ by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are
+ using support eager evaluation with HTML.
+ """
+ import cgi
+ if not self._support_repr_html:
+ self._support_repr_html = True
+ if self._eager_eval:
+ max_num_rows = max(self._max_num_rows, 0)
+ vertical = False
+ sock_info = self._jdf.getRowsToPython(
+ max_num_rows, self._truncate, vertical)
+ rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
+ head = rows[0]
+ row_data = rows[1:]
+ has_more_data = len(row_data) > max_num_rows
+ row_data = row_data[:max_num_rows]
+
+ html = "\n"
+ # generate table head
+ html += "| %s |
\n" % "".join(map(lambda x: cgi.escape(x), head))
+ # generate table rows
+ for row in row_data:
+ html += " | | %s |
\n" % "".join(
+ map(lambda x: cgi.escape(x), row))
+ html += " |
\n"
+ if has_more_data:
+ html += "only showing top %d %s\n" % (
+ max_num_rows, "row" if max_num_rows == 1 else "rows")
+ return html
+ else:
+ return None
@since(2.1)
def checkpoint(self, eager=True):
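
A note on the `_support_repr_html` flag above: IPython-style front ends typically compute both a plain-text and an HTML representation, so without the flag an eagerly evaluated DataFrame could trigger two `_jdf` calls per display. A rough, hypothetical sketch of the call pattern the flag guards against (the `FakeNotebook` class is not real code, only an illustration of the display protocol):

```python
class FakeNotebook:
    """Hypothetical front end: prefers _repr_html_ and falls back to __repr__."""

    def display(self, obj):
        repr_html = getattr(obj, "_repr_html_", None)
        if repr_html is not None:
            rendered = repr_html()   # sets obj._support_repr_html = True on a DataFrame
            if rendered is not None:
                return rendered      # eager evaluation output as an HTML table
        return repr(obj)             # after the call above, __repr__ stays cheap (no extra job)
```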
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ea2dd7605dc5..487eb19c3b98 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3074,6 +3074,36 @@ def test_checking_csv_header(self):
finally:
shutil.rmtree(path)
+ def test_repr_html(self):
+ import re
+ pattern = re.compile(r'^ *\|', re.MULTILINE)
+ df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
+ self.assertEquals(None, df._repr_html_())
+ with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+ expected1 = """
+ || key | value |
+ || 1 | 1 |
+ || 22222 | 22222 |
+ |
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
+ with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+ expected2 = """
+ || key | value |
+ || 1 | 1 |
+ || 222 | 222 |
+ |
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
+ with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+ expected3 = """
+ |only showing top 1 row
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+
class HiveSparkSubmitTests(SparkSubmitTests):
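
The expected strings in `test_repr_html` use a leading `|` margin so that they can stay indented with the test code; `re.sub(pattern, '', ...)` strips that margin before the comparison. A standalone sketch of the same idiom (the sample string is illustrative):

```python
import re

# Remove the leading spaces and "|" margin from every line, similar to
# Scala's stripMargin, so the literal can be indented freely in the source.
pattern = re.compile(r'^ *\|', re.MULTILINE)
expected = """<table border='1'>
    |<tr><th>key</th></tr>
    |</table>
    |"""
print(re.sub(pattern, '', expected))
# <table border='1'>
# <tr><th>key</th></tr>
# </table>
```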
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index abb5ae53f4d7..f5526104690d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -231,16 +231,17 @@ class Dataset[T] private[sql](
}
/**
- * Compose the string representing rows for output
+ * Get rows represented as a sequence of string sequences, honoring the given truncate and vertical requirements.
*
- * @param _numRows Number of rows to show
+ * @param numRows Number of rows to return
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
- * @param vertical If set to true, prints output rows vertically (one line per column value).
+ * @param vertical If set to true, the returned rows do not need to be truncated.
*/
- private[sql] def showString(
- _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
- val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ private[sql] def getRows(
+ numRows: Int,
+ truncate: Int,
+ vertical: Boolean): Seq[Seq[String]] = {
val newDf = toDF()
val castCols = newDf.logicalPlan.output.map { col =>
// Since binary types in top-level schema fields have a specific format to print,
@@ -251,14 +252,12 @@ class Dataset[T] private[sql](
Column(col).cast(StringType)
}
}
- val takeResult = newDf.select(castCols: _*).take(numRows + 1)
- val hasMoreData = takeResult.length > numRows
- val data = takeResult.take(numRows)
+ val data = newDf.select(castCols: _*).take(numRows + 1)
// For array values, replace Seq and Array with square brackets
// For cells that are beyond `truncate` characters, replace it with the
// first `truncate-3` and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ schema.fieldNames.toSeq +: data.map { row =>
row.toSeq.map { cell =>
val str = cell match {
case null => "null"
@@ -274,6 +273,26 @@ class Dataset[T] private[sql](
}
}: Seq[String]
}
+ }
+
+ /**
+ * Compose the string representing rows for output
+ *
+ * @param _numRows Number of rows to show
+ * @param truncate If set to more than 0, truncates strings to `truncate` characters and
+ * all cells will be aligned right.
+ * @param vertical If set to true, prints output rows vertically (one line per column value).
+ */
+ private[sql] def showString(
+ _numRows: Int,
+ truncate: Int = 20,
+ vertical: Boolean = false): String = {
+ val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ // Get rows represented by Seq[Seq[String]]; we may get one extra row, which signals that there is more data.
+ val tmpRows = getRows(numRows, truncate, vertical)
+
+ val hasMoreData = tmpRows.length - 1 > numRows
+ val rows = tmpRows.take(numRows + 1)
val sb = new StringBuilder
val numCols = schema.fieldNames.length
@@ -291,31 +310,25 @@ class Dataset[T] private[sql](
}
}
+ val paddedRows = rows.map { row =>
+ row.zipWithIndex.map { case (cell, i) =>
+ if (truncate > 0) {
+ StringUtils.leftPad(cell, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell, colWidths(i))
+ }
+ }
+ }
+
// Create SeparateLine
val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
// column names
- rows.head.zipWithIndex.map { case (cell, i) =>
- if (truncate > 0) {
- StringUtils.leftPad(cell, colWidths(i))
- } else {
- StringUtils.rightPad(cell, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
-
+ paddedRows.head.addString(sb, "|", "|", "|\n")
sb.append(sep)
// data
- rows.tail.foreach {
- _.zipWithIndex.map { case (cell, i) =>
- if (truncate > 0) {
- StringUtils.leftPad(cell.toString, colWidths(i))
- } else {
- StringUtils.rightPad(cell.toString, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
- }
-
+ paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
sb.append(sep)
} else {
// Extended display mode enabled
@@ -346,7 +359,7 @@ class Dataset[T] private[sql](
}
// Print a footer
- if (vertical && data.isEmpty) {
+ if (vertical && rows.tail.isEmpty) {
// In a vertical mode, print an empty row set explicitly
sb.append("(0 rows)\n")
} else if (hasMoreData) {
@@ -3209,6 +3222,19 @@ class Dataset[T] private[sql](
}
}
+ private[sql] def getRowsToPython(
+ _numRows: Int,
+ truncate: Int,
+ vertical: Boolean): Array[Any] = {
+ EvaluatePython.registerPicklers()
+ val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+ val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
+ val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+ rows.iterator.map(toJava))
+ PythonRDD.serveIterator(iter, "serve-GetRows")
+ }
+
/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/
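
Both `showString` and the new `getRowsToPython` rely on the same footer trick: `getRows` is asked for `numRows + 1` rows, and the presence of the extra row is the only signal that the result was cut off, so no separate count job is needed. A language-neutral sketch of that logic in Python (hypothetical helper name, not part of the patch):

```python
def summarize(fetched_rows, num_rows):
    """fetched_rows holds the header row plus up to num_rows + 1 data rows."""
    header, data = fetched_rows[0], fetched_rows[1:]
    has_more_data = len(data) > num_rows      # the extra row means the result was truncated
    shown = data[:num_rows]
    footer = None
    if has_more_data:
        footer = "only showing top %d %s" % (num_rows, "row" if num_rows == 1 else "rows")
    return header, shown, footer
```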