14 changes: 13 additions & 1 deletion python/pyspark/sql.py
@@ -268,7 +268,7 @@ def __init__(self, jschema_rdd, sql_ctx):
     def _jrdd(self):
         """
         Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the
-        L{pyspark.rdd.RDD} super class (map, count, etc.).
+        L{pyspark.rdd.RDD} super class (map, filter, etc.).
         """
         if not hasattr(self, '_lazy_jrdd'):
             self._lazy_jrdd = self._toPython()._jrdd
@@ -321,6 +321,18 @@ def saveAsTable(self, tableName):
"""
self._jschema_rdd.saveAsTable(tableName)

def count(self):
"""
Return the number of elements in this RDD.

>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.count()
3L
>>> srdd.count() == srdd.map(lambda x: x).count()
True
"""
return self._jschema_rdd.count()

     def _toPython(self):
         # We have to import the Row class explicitly, so that the reference Pickler has is
         # pyspark.sql.Row instead of __main__.Row
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -151,7 +151,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
-  override def dataType = IntegerType
+  override def dataType = LongType
   override def toString = s"COUNT($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -295,12 +295,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
 case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
-  var count: Int = _
+  var count: Long = _
 
   override def update(input: Row): Unit = {
     val evaluatedExpr = expr.map(_.eval(input))
     if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
-      count += 1
+      count += 1L
     }
   }
 
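The IntegerType-to-LongType switch above matters once a table has more rows than an Int can represent. A minimal, self-contained sketch of the failure mode (illustrative only, no Spark APIs involved):

```scala
// Why Count's result type must be 64-bit: an Int counter wraps around
// past Int.MaxValue (~2.1 billion rows), while a Long stays correct.
object CountWidthDemo extends App {
  val rows = Int.MaxValue.toLong + 1L  // one more row than an Int can hold
  println(rows.toInt)                  // -2147483648: the Int count has overflowed
  println(rows)                        // 2147483648: the Long count is right
}
```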
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -274,6 +274,15 @@ class SchemaRDD(
       seed: Long) =
     new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
 
+  /**
+   * :: Experimental ::
+   * Overriding base RDD implementation to leverage query optimizer
+   */
+  @Experimental
+  override def count(): Long = {
Contributor:
Do you mind adding javadoc for this? Just explain that, unlike RDD's count, SchemaRDD's count actually invokes the optimizer.

Contributor Author:
Sure, will do.
+    groupBy()(Count(Literal(1))).collect().head.getLong(0)
+  }
+
   /**
    * :: Experimental ::
    * Applies the given Generator, or table generating function, to this relation.
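As a usage sketch of the new override under the Spark 1.x API (the names `sc`, `sqlContext`, and `Record` are assumptions, not part of this change):

```scala
import org.apache.spark.sql.{SQLContext, SchemaRDD}

case class Record(i: Int)

// assumes an existing SparkContext `sc`
val sqlContext = new SQLContext(sc)
import sqlContext.createSchemaRDD  // implicit RDD[Product] -> SchemaRDD conversion

val records: SchemaRDD = sc.parallelize(1 to 100).map(Record(_))
records.count()  // planned as a COUNT(1) aggregate; returns 100L
```

Note the explicit `SchemaRDD` type ascription: plain `RDD` also defines `count()`, so the implicit conversion only kicks in when the value is typed as a SchemaRDD.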
sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -108,10 +108,7 @@ class DslQuerySuite extends QueryTest {
   }
 
   test("count") {
-    checkAnswer(
-      testData2.groupBy()(Count(1)),
-      testData2.count()
-    )
+    assert(testData2.count() === testData2.map(_ => 1).count())
   }
 
   test("null count") {
@@ -126,6 +123,10 @@
     )
   }
 
+  test("zero count") {
+    assert(testData4.count() === 0)
+  }
+
   test("inner join where, one match per row") {
     checkAnswer(
       upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
2 changes: 2 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,6 +47,8 @@ object TestData {
     (1, null) ::
     (2, 2) :: Nil)
 
+  val testData4 = logical.LocalRelation('a.int, 'b.int)
+
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =
     TestSQLContext.sparkContext.parallelize(