39 changes: 24 additions & 15 deletions python/pyspark/sql.py
@@ -1973,7 +1973,7 @@ def collect(self):
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
with SCCallSiteSync(self._sc) as css:
bytesInJava = self._jdf.collectToPython().iterator()
bytesInJava = self._jdf.javaToPython().collect().iterator()
cls = _create_cls(self.schema())
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile.close()
@@ -1997,14 +1997,14 @@ def take(self, num):
return self.limit(num).collect()

def map(self, f):
""" Return a new RDD by applying a function to each Row, it's a
shorthand for df.rdd.map()
"""
return self.rdd.map(f)

# Convert each object in the RDD to a Row with the right class
# for this DataFrame, so that fields can be accessed as attributes.
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition of this RDD,
while tracking the index of the original partition.
Return a new RDD by applying a function to each partition.

>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
@@ -2013,21 +2013,28 @@ def mapPartitions(self, f, preservesPartitioning=False):
"""
return self.rdd.mapPartitions(f, preservesPartitioning)

# We override the default cache/persist/checkpoint behavior
# as we want to cache the underlying DataFrame object in the JVM,
# not the PythonRDD checkpointed by the super class
def cache(self):
""" Persist with the default storage level (C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
self._jdf.cache()
return self

def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
""" Set the storage level to persist its values across operations
after the first time it is computed. This can only be used to assign
a new storage level if the RDD does not have a storage level set yet.
If no storage level is specified, it defaults to (C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
self._jdf.persist(javaStorageLevel)
return self

def unpersist(self, blocking=True):
""" Mark it as non-persistent, and remove all blocks for it from
memory and disk.
"""
self.is_cached = False
self._jdf.unpersist(blocking)
return self
@@ -2036,10 +2043,12 @@ def unpersist(self, blocking=True):
# rdd = self._jdf.coalesce(numPartitions, shuffle, None)
# return DataFrame(rdd, self.sql_ctx)

# def repartition(self, numPartitions):
# rdd = self._jdf.repartition(numPartitions, None)
# return DataFrame(rdd, self.sql_ctx)
#
def repartition(self, numPartitions):
""" Return a new :class:`DataFrame` that has exactly `numPartitions`
partitions.
"""
rdd = self._jdf.repartition(numPartitions, None)
return DataFrame(rdd, self.sql_ctx)

def sample(self, withReplacement, fraction, seed=None):
"""
@@ -2359,11 +2368,11 @@ def _scalaMethod(name):
""" Translate operators into methodName in Scala

For example:
>>> scalaMethod('+')
>>> _scalaMethod('+')
'$plus'
>>> scalaMethod('>=')
>>> _scalaMethod('>=')
'$greater$eq'
>>> scalaMethod('cast')
>>> _scalaMethod('cast')
'cast'
"""
return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
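For context, a minimal usage sketch of the DataFrame methods touched in this file (persist/unpersist, repartition, map/mapPartitions). The local SparkContext/SQLContext setup and the inferSchema call are illustrative assumptions, not part of the patch:

```python
# Illustrative only -- not part of the patch.
from pyspark import SparkContext, StorageLevel
from pyspark.sql import SQLContext, Row

sc = SparkContext("local[2]", "dataframe_methods_sketch")
sqlCtx = SQLContext(sc)
# inferSchema is assumed here as the available DataFrame constructor in this snapshot.
df = sqlCtx.inferSchema(sc.parallelize([Row(key=i, value=str(i)) for i in range(100)]))

df.persist(StorageLevel.MEMORY_ONLY_SER)   # caches the underlying JVM DataFrame, not a PythonRDD
df4 = df.repartition(4)                    # new DataFrame with exactly 4 partitions

keys = df.map(lambda row: row.key)                         # shorthand for df.rdd.map(...)
sizes = df.mapPartitions(lambda it: [sum(1 for _ in it)])  # number of rows in each partition

print(keys.take(3))     # [0, 1, 2]
print(sizes.collect())  # one count per partition of df
df.unpersist()
sc.stop()
```

Because cache/persist/unpersist go through `self._jdf`, the storage level is attached to the JVM-side DataFrame rather than to a serialized PythonRDD, which is what the new comment above those overrides explains.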
6 changes: 3 additions & 3 deletions python/pyspark/tests.py
@@ -946,8 +946,7 @@ def test_apply_schema_with_udt(self):
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
df = self.sqlCtx.applySchema(rdd, schema)
# TODO: test collect with UDT
point = df.rdd.first().point
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))

def test_parquet_with_udt(self):
@@ -984,11 +983,12 @@ def test_column_select(self):
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

def test_aggregator(self):
from pyspark.sql import Aggregator as Agg
df = self.df
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
# TODO(davies): fix aggregators
from pyspark.sql import Aggregator as Agg
# self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))


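As a pointer for readers of `test_aggregator`, the dictionary form of `agg` maps column names to aggregate function names. A brief sketch of the same calls outside the test harness, assuming `df` is the 100-row key/value DataFrame used as the test fixture:

```python
# Sketch of the dict-style aggregation exercised in test_aggregator above.
# Assumes `df` has an integer `key` column (0..99) and a string `value` column.
g = df.groupBy()                                 # no grouping key: a single global group
agg_row = g.agg({'key': 'max', 'value': 'count'}).collect()[0]
print(sorted(agg_row))                           # [99, 100]: max(key) and count(value)
print(g.mean().collect())                        # average of every numeric column
```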
12 changes: 1 addition & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -590,17 +590,7 @@ class DataFrame protected[sql](
*/
protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val fieldTypes = schema.fields.map(_.dataType)
val jrdd = this.rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
SerDeUtil.javaToPython(jrdd)
}
/**
* Serializes the Array[Row] returned by collect(), using the same format as javaToPython.
*/
protected[sql] def collectToPython: JList[Array[Byte]] = {
val fieldTypes = schema.fields.map(_.dataType)
val pickle = new Pickler
new ArrayList[Array[Byte]](collect().map { row =>
EvaluatePython.rowToArray(row, fieldTypes)
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
}
}
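Finally, a conceptual sketch of how the Python `collect()` above consumes what `javaToPython` produces now that `collectToPython` is gone. The real code spools the collected batches through a temp file and a batched pickle deserializer; the direct `pickle.loads` below and the `collect_rows` helper are illustrative simplifications:

```python
# Conceptual sketch only -- not the actual pyspark code path.
import pickle

from pyspark.sql import _create_cls  # internal helper also used by DataFrame.collect()

def collect_rows(df):
    # javaToPython() yields a JavaRDD[Array[Byte]] whose elements are pickled
    # batches of row values; collect() brings those batches to the driver.
    cls = _create_cls(df.schema())               # Row subclass generated for this schema
    rows = []
    for batch in df._jdf.javaToPython().collect():
        for values in pickle.loads(bytes(batch)):
            rows.append(cls(values))             # wrap the raw values in the schema's Row class
    return rows
```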