From e67c910a5777300f1dc6d9c4908c0794dcd12863 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Tue, 20 May 2014 12:24:47 -0700 Subject: [PATCH 1/3] [SPARK-1822] SchemaRDD.count() should use optimizer --- python/pyspark/sql.py | 14 +++++++++++++- .../sql/catalyst/expressions/aggregates.scala | 6 +++--- .../scala/org/apache/spark/sql/SchemaRDD.scala | 4 ++++ .../scala/org/apache/spark/sql/DslQuerySuite.scala | 9 +++++---- .../test/scala/org/apache/spark/sql/TestData.scala | 3 +++ 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index bbe69e7d8f89..f2001afae4ee 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -268,7 +268,7 @@ def __init__(self, jschema_rdd, sql_ctx): def _jrdd(self): """ Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, count, etc.). + L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): self._lazy_jrdd = self._toPython()._jrdd @@ -321,6 +321,18 @@ def saveAsTable(self, tableName): """ self._jschema_rdd.saveAsTable(tableName) + def count(self): + """ + Return the number of elements in this RDD. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.count() + 3L + >>> srdd.count() == srdd.map(lambda x: x).count() + True + """ + return self._jschema_rdd.count() + def _toPython(self): # We have to import the Row class explicitly, so that the reference Pickler has is # pyspark.sql.Row instead of __main__.Row diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5dbaaa3b0ce3..1bcd4e22766a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -151,7 +151,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def references = child.references override def nullable = false - override def dataType = IntegerType + override def dataType = LongType override def toString = s"COUNT($child)" override def asPartial: SplitEvaluation = { @@ -295,12 +295,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - var count: Int = _ + var count: Long = _ override def update(input: Row): Unit = { val evaluatedExpr = expr.map(_.eval(input)) if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) { - count += 1 + count += 1L } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 2569815ebb20..fb996434e3ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -274,6 +274,10 @@ class SchemaRDD( seed: Long) = new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) + override def count(): Long = { + groupBy()(Count(Literal(1))).collect().head.getLong(0) + } + /** * :: Experimental :: * Applies the given Generator, or table generating function, to this relation. 
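A note on the change above before the test diffs: SchemaRDD.count() now compiles down to a zero-key Catalyst aggregate instead of counting deserialized rows on the driver side, and Count yields LongType so results past Int.MaxValue do not wrap. A minimal sketch of both points, assuming a SQLContext named sqlContext is in scope along with the catalyst expression imports this patch already uses (the "people" table is illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.{Count, Literal}

    // Any SchemaRDD works; sqlContext.sql() is one way to obtain one.
    val people = sqlContext.sql("SELECT name, age FROM people")

    // After this patch, count() is equivalent to the aggregate it delegates to:
    val viaOptimizer: Long =
      people.groupBy()(Count(Literal(1))).collect().head.getLong(0)
    assert(people.count() == viaOptimizer)

    // Why Count.dataType moved from IntegerType to LongType: an Int counter
    // wraps past ~2.1 billion rows, while a Long does not.
    assert(Int.MaxValue.toLong + 1L == 2147483648L)
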
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index f43e98d61409..233132a2fec6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -108,10 +108,7 @@ class DslQuerySuite extends QueryTest { } test("count") { - checkAnswer( - testData2.groupBy()(Count(1)), - testData2.count() - ) + assert(testData2.count() === testData2.map(_ => 1).count()) } test("null count") { @@ -126,6 +123,10 @@ class DslQuerySuite extends QueryTest { ) } + test("zero count") { + assert(testData4.count() === 0) + } + test("inner join where, one match per row") { checkAnswer( upperCaseData.join(lowerCaseData, Inner).where('n === 'N), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 1aca3872524d..254c3b199642 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -47,6 +47,9 @@ object TestData { (1, null) :: (2, 2) :: Nil) + val testData4 = + logical.LocalRelation('a.int, 'b.int) + case class UpperCaseData(N: Int, L: String) val upperCaseData = TestSQLContext.sparkContext.parallelize( From cf4baa4abc3cd93e56cb57e403a8ebb3b2c6da56 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Fri, 23 May 2014 12:34:08 -0700 Subject: [PATCH 2/3] [SPARK-1822] Adding Scaladoc --- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index fb996434e3ba..452da3d02310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -274,6 +274,11 @@ class SchemaRDD( seed: Long) = new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) + /** + * :: Experimental :: + * Overriding base RDD implementation to leverage query optimizer + */ + @Experimental override def count(): Long = { groupBy()(Count(Literal(1))).collect().head.getLong(0) } From 2f8072a1f0937747052458cfaf63a0d10d930e8a Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Fri, 23 May 2014 12:42:17 -0700 Subject: [PATCH 3/3] [SPARK-1822] Minor style update --- sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 254c3b199642..b1eecb4dd3be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -47,8 +47,7 @@ object TestData { (1, null) :: (2, 2) :: Nil) - val testData4 = - logical.LocalRelation('a.int, 'b.int) + val testData4 = logical.LocalRelation('a.int, 'b.int) case class UpperCaseData(N: Int, L: String) val upperCaseData =
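The empty relation testData4 added above guards the zero-count edge case: with no input rows, CountFunction's accumulator is never incremented and the aggregate returns 0 rather than failing. A standalone sketch of that accumulator logic (plain Scala mirroring CountFunction.update from patch 1/3, not Spark code; the sample rows are illustrative):

    // Each inner Seq stands in for the evaluated expressions of one input row.
    def countNonNull(rows: Seq[Seq[Any]]): Long = {
      var count: Long = 0L                        // Long, matching the patched CountFunction
      for (row <- rows if row.exists(_ != null))  // count a row only if some value is non-null
        count += 1L
      count
    }

    assert(countNonNull(Seq.empty) == 0L)                       // the testData4 "zero count" case
    assert(countNonNull(Seq(Seq(1), Seq(null), Seq(2))) == 2L)  // nulls excluded, as in the "null count" test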