Commit 9be0110

Added tests and cleanup
1 parent 4eaef85 commit 9be0110

File tree: 4 files changed, +98 -85 lines changed


python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py

Lines changed: 13 additions & 0 deletions
@@ -18,6 +18,7 @@
 import unittest
 
 from pyspark.rdd import PythonEvalType
+from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
     udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import *
@@ -461,6 +462,18 @@ def test_register_vectorized_udf_basic(self):
         expected = [1, 5]
         self.assertEqual(actual, expected)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, sum=5), Row(id=2, sum=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda x: x.sum(),
+                       'int', PandasUDFType.GROUPED_AGG)
+
+        result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
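
The scenario this new test exercises can also be reproduced interactively. The sketch below is illustrative only and not part of the commit: it assumes a local SparkSession (the master("local[2]") setting is an arbitrary choice) and pyarrow installed, and it uses glom() only to confirm that parallelizing 3 rows into 4 slices really leaves at least one partition empty.

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.master("local[2]").getOrCreate()
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]

# More slices than rows, so at least one partition ends up empty.
rdd = spark.sparkContext.parallelize(data, numSlices=len(data) + 1)
print([len(part) for part in rdd.glom().collect()])  # e.g. [0, 1, 1, 1]

df = spark.createDataFrame(rdd)
agg_sum = pandas_udf(lambda x: x.sum(), 'int', PandasUDFType.GROUPED_AGG)
result = df.groupBy('id').agg(agg_sum(df['x']).alias('sum')).collect()
print(sorted(result))  # [Row(id=1, sum=5), Row(id=2, sum=4)]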

python/pyspark/sql/tests/test_pandas_udf_grouped_map.py

Lines changed: 12 additions & 0 deletions
@@ -504,6 +504,18 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
 
         self.assertEquals(result.collect()[0]['sum'], 165)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()),
+                       'id long, x int', PandasUDFType.GROUPED_MAP)
+
+        result = df.groupBy('id').apply(f).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_map import *
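
The grouped-map UDF body is ordinary pandas code, which is why the expected result repeats the per-group sum on every row of the group. A quick pandas-only check of what the lambda does to the id=1 group (illustrative only, not part of the commit; the input DataFrame here is hypothetical):

import pandas as pd

# The id=1 group roughly as the Python worker would receive it (hypothetical input).
pdf = pd.DataFrame({'id': [1, 1], 'x': [2, 3]})

# Same function the test wraps as a GROUPED_MAP pandas UDF.
func = lambda pdf: pdf.assign(x=pdf['x'].sum())
print(func(pdf))
#    id  x
# 0   1  5
# 1   1  5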

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala

Lines changed: 42 additions & 47 deletions
@@ -105,58 +105,53 @@ case class AggregateInPandasExec(
       StructField(s"_$i", dt)
     })
 
-    inputRDD.mapPartitionsInternal { iter =>
+    // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
+      val prunedProj = UnsafeProjection.create(allInputs, child.output)
 
-      // Only execute on non-empty partitions
-      if (iter.nonEmpty) {
-        val prunedProj = UnsafeProjection.create(allInputs, child.output)
-
-        val grouped = if (groupingExpressions.isEmpty) {
-          // Use an empty unsafe row as a place holder for the grouping key
-          Iterator((new UnsafeRow(), iter))
-        } else {
-          GroupedIterator(iter, groupingExpressions, child.output)
-        }.map { case (key, rows) =>
-          (key, rows.map(prunedProj))
-        }
+      val grouped = if (groupingExpressions.isEmpty) {
+        // Use an empty unsafe row as a place holder for the grouping key
+        Iterator((new UnsafeRow(), iter))
+      } else {
+        GroupedIterator(iter, groupingExpressions, child.output)
+      }.map { case (key, rows) =>
+        (key, rows.map(prunedProj))
+      }
 
-        val context = TaskContext.get()
+      val context = TaskContext.get()
 
-        // The queue used to buffer input rows so we can drain it to
-        // combine input with output from Python.
-        val queue = HybridRowQueue(context.taskMemoryManager(),
-          new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
-        context.addTaskCompletionListener[Unit] { _ =>
-          queue.close()
-        }
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
+      context.addTaskCompletionListener[Unit] { _ =>
+        queue.close()
+      }
 
-        // Add rows to queue to join later with the result.
-        val projectedRowIter = grouped.map { case (groupingKey, rows) =>
-          queue.add(groupingKey.asInstanceOf[UnsafeRow])
-          rows
-        }
+      // Add rows to queue to join later with the result.
+      val projectedRowIter = grouped.map { case (groupingKey, rows) =>
+        queue.add(groupingKey.asInstanceOf[UnsafeRow])
+        rows
+      }
 
-        val columnarBatchIter = new ArrowPythonRunner(
-          pyFuncs,
-          PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-          argOffsets,
-          aggInputSchema,
-          sessionLocalTimeZone,
-          pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
-
-        val joinedAttributes =
-          groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
-        val joined = new JoinedRow
-        val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)
-
-        columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
-          val leftRow = queue.remove()
-          val joinedRow = joined(leftRow, aggOutputRow)
-          resultProj(joinedRow)
-        }
-      } else {
-        Iterator.empty
+      val columnarBatchIter = new ArrowPythonRunner(
+        pyFuncs,
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+        argOffsets,
+        aggInputSchema,
+        sessionLocalTimeZone,
+        pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
+
+      val joinedAttributes =
+        groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
+      val joined = new JoinedRow
+      val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)
+
+      columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
+        val leftRow = queue.remove()
+        val joinedRow = joined(leftRow, aggOutputRow)
+        resultProj(joinedRow)
       }
-    }
+    }}
   }
 }
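
The cleanup in both Scala files replaces the old if (iter.nonEmpty) { ... } else { Iterator.empty } branch with a single guard at the top of the partition closure, mapPartitionsInternal { iter => if (iter.isEmpty) iter else { ... } }, so an empty partition is simply passed through instead of being replaced by Iterator.empty. A rough Python analogue of that guard, illustrative only (the committed change is the Scala above and in FlatMapGroupsInPandasExec below; process_partition and some_rdd are hypothetical names):

from itertools import chain

# Rough Python analogue (illustrative only) of the Scala guard
#   mapPartitionsInternal { iter => if (iter.isEmpty) iter else { ... } }
def process_partition(rows):
    rows = iter(rows)
    first = next(rows, None)      # Python iterators have no isEmpty, so peek instead
    if first is None:
        return iter([])           # empty partition: nothing to group, skip the work
    # Non-empty partition: put the peeked row back and do the real per-group work here.
    return chain([first], rows)

# e.g. some_rdd.mapPartitions(process_partition)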

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala

Lines changed: 31 additions & 38 deletions
@@ -125,45 +125,38 @@ case class FlatMapGroupsInPandasExec(
     val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
     val dedupSchema = StructType.fromAttributes(dedupAttributes)
 
-    inputRDD.mapPartitionsInternal { iter =>
-
-      // Only execute on non-empty partitions
-      if (iter.nonEmpty) {
-
-        val grouped = if (groupingAttributes.isEmpty) {
-          Iterator(iter)
-        } else {
-          val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-          val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
-          groupedIter.map {
-            case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
-          }
-        }
-
-        val context = TaskContext.get()
-
-        val columnarBatchIter = new ArrowPythonRunner(
-          chainedFunc,
-          PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-          argOffsets,
-          dedupSchema,
-          sessionLocalTimeZone,
-          pythonRunnerConf).compute(grouped, context.partitionId(), context)
-
-        val unsafeProj = UnsafeProjection.create(output, output)
-
-        columnarBatchIter.flatMap { batch =>
-          // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
-          val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
-          val outputVectors = output.indices.map(structVector.getChild)
-          val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
-          flattenedBatch.setNumRows(batch.numRows())
-          flattenedBatch.rowIterator.asScala
-        }.map(unsafeProj)
-
+    // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
+      val grouped = if (groupingAttributes.isEmpty) {
+        Iterator(iter)
       } else {
-        Iterator.empty
+        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+        val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
+        groupedIter.map {
+          case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
+        }
       }
-    }
+
+      val context = TaskContext.get()
+
+      val columnarBatchIter = new ArrowPythonRunner(
+        chainedFunc,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        argOffsets,
+        dedupSchema,
+        sessionLocalTimeZone,
+        pythonRunnerConf).compute(grouped, context.partitionId(), context)
+
+      val unsafeProj = UnsafeProjection.create(output, output)
+
+      columnarBatchIter.flatMap { batch =>
+        // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
+        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
+        val outputVectors = output.indices.map(structVector.getChild)
+        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
+        flattenedBatch.setNumRows(batch.numRows())
+        flattenedBatch.rowIterator.asScala
+      }.map(unsafeProj)
+    }}
   }
 }
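
The comment in the diff above notes that a grouped-map UDF's result arrives from the Python worker as a single StructType column whose children are the output columns. A small pyarrow sketch of that shape, illustrative only (the column name '_0' and the values are made up; the committed code does the equivalent on ColumnarBatch/ArrowColumnVector in the JVM):

import pyarrow as pa

# A batch holding one struct column with children 'id' and 'x' (hypothetical values).
struct = pa.StructArray.from_arrays(
    [pa.array([1, 1, 2]), pa.array([5, 5, 4])], names=['id', 'x'])
batch = pa.RecordBatch.from_arrays([struct], names=['_0'])

# "Select the children": flatten the struct column into its field arrays.
id_col, x_col = batch.column(0).flatten()
print(id_col.to_pylist(), x_col.to_pylist())  # [1, 1, 2] [5, 5, 4]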
