Commit 40491c9

PR Changes + Method Visibility
1 parent 1836944 commit 40491c9

File tree: 5 files changed (+35, -37 lines)

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 12 additions & 18 deletions
@@ -286,39 +286,33 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[_] = {
-    pyRDD.rdd.mapPartitions { iter =>
-      val unpickle = new Unpickler
-      // TODO: Figure out why flatMap is necessay for pyspark
-      iter.flatMap { row =>
-        unpickle.loads(row) match {
-          case objs: java.util.ArrayList[Any] => objs
-          // Incase the partition doesn't have a collection
-          case obj => Seq(obj)
-        }
-      }
-    }
-  }
-
+  /**
+   * Convert an RDD of serialized Python dictionaries to Scala Maps
+   * TODO: Support more Python types.
+   */
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
       // TODO: Figure out why flatMap is necessay for pyspark
       iter.flatMap { row =>
         unpickle.loads(row) match {
-          case objs: java.util.ArrayList[JMap[String, _]] => objs.map(_.toMap)
+          case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
           // Incase the partition doesn't have a collection
-          case obj: JMap[String, _] => Seq(obj.toMap)
+          case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
         }
       }
     }
   }
 
+  /**
+   * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
+   * PySpark.
+   */
   def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
     jRDD.rdd.mapPartitions { iter =>
-      val unpickle = new Pickler
+      val pickle = new Pickler
       iter.map { row =>
-        unpickle.dumps(row)
+        pickle.dumps(row)
       }
     }
   }
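
Both remaining converters lean on Pyrolite (net.razorvine.pickle). As a rough, self-contained sketch of the loads side that pythonToJavaMap builds on: a pickled list comes back from Unpickler.loads as a java.util.ArrayList, while anything else comes back as a single object, which is why the method flatMaps over both cases. The sample data below is invented, and Pyrolite's Pickler stands in for Python's pickle module, which produces the bytes in the real code path.

import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

import net.razorvine.pickle.{Pickler, Unpickler}

object UnpickleSketch {
  def main(args: Array[String]): Unit = {
    // Build a pickled batch of one dict. In PySpark the bytes come from Python's
    // pickle module; Pyrolite's Pickler stands in for it in this sketch.
    val dict = new JHashMap[String, Any]()
    dict.put("field1", 1)
    dict.put("field2", "row1")
    val batch = new JArrayList[Any]()
    batch.add(dict)
    val bytes: Array[Byte] = new Pickler().dumps(batch)

    // A pickled list unpickles to a java.util.ArrayList; anything else comes back
    // as a single object. pythonToJavaMap handles both shapes inside flatMap.
    new Unpickler().loads(bytes) match {
      case objs: JArrayList[_] => println(s"batch of ${objs.size}: $objs")
      case other               => println(s"single object: $other")
    }
  }
}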

python/pyspark/context.py

Lines changed: 11 additions & 12 deletions
@@ -174,7 +174,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
                 SparkContext._gateway = gateway or launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
-                SparkContext._pythonToJava = SparkContext._jvm.PythonRDD.pythonToJava
                 SparkContext._pythonToJavaMap = SparkContext._jvm.PythonRDD.pythonToJavaMap
                 SparkContext._javaToPython = SparkContext._jvm.PythonRDD.javaToPython
 
@@ -481,21 +480,21 @@ def __init__(self, sparkContext):
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
 
-        >>> srdd = sqlCtx.applySchema(rdd)
-        >>> sqlCtx.applySchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
         ValueError:...
 
         >>> bad_rdd = sc.parallelize([1,2,3])
-        >>> sqlCtx.applySchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
         ValueError:...
 
         >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
         ... "boolean" : True}])
-        >>> srdd = sqlCtx.applySchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
+        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
         ... x.boolean))
         >>> srdd.collect()[0]
         (1, u'string', 1.0, 1, True)
@@ -514,7 +513,7 @@ def _ssql_ctx(self):
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
-    def applySchema(self, rdd):
+    def inferSchema(self, rdd):
         """
         Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
         determine the fields names and types, and then use that to extract all the dictionaries.
@@ -523,7 +522,7 @@ def applySchema(self, rdd):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ... {"field1" : 3, "field2": "row3"}]
         True
@@ -535,7 +534,7 @@ def applySchema(self, rdd):
                             (SchemaRDD.__name__, rdd.first()))
 
         jrdd = self._sc._pythonToJavaMap(rdd._jrdd)
-        srdd = self._ssql_ctx.applySchema(jrdd.rdd())
+        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
         return SchemaRDD(srdd, self)
 
     def registerRDDAsTable(self, rdd, tableName):
@@ -546,7 +545,7 @@ def registerRDDAsTable(self, rdd, tableName):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
         """
         if (rdd.__class__ is SchemaRDD):
@@ -563,7 +562,7 @@ def parquetFile(self, path):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
         >>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
         >>> srdd.collect() == srdd2.collect()
@@ -580,7 +579,7 @@ def sql(self, sqlQuery):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
         >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
         >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
@@ -596,7 +595,7 @@ def table(self, tableName):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
         >>> srdd2 = sqlCtx.table("table1")
         >>> srdd.collect() == srdd2.collect()

python/pyspark/rdd.py

Lines changed: 2 additions & 2 deletions
@@ -1445,7 +1445,7 @@ def saveAsParquetFile(self, path):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.saveAsParquetFile("/tmp/test.parquet")
         >>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
         >>> srdd2.collect() == srdd.collect()
@@ -1461,7 +1461,7 @@ def registerAsTable(self, name):
         >>> sqlCtx = SQLContext(sc)
         >>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         ... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-        >>> srdd = sqlCtx.applySchema(rdd)
+        >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.registerAsTable("test")
         >>> srdd2 = sqlCtx.sql("select * from test")
         >>> srdd.collect() == srdd2.collect()

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 5 additions & 3 deletions
@@ -243,9 +243,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
     def debugExec() = DebugQuery(executedPlan).execute().collect()
   }
 
-  // TODO: We only support primitive types, add support for nested types. Difficult because java
-  // objects don't have classTags
-  def applySchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+  /**
+   * Peek at the first row of the RDD and infer its schema.
+   * TODO: We only support primitive types, add support for nested types.
+   */
+  private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
     val schema = rdd.first.map { case (fieldName, obj) =>
       val dataType = obj.getClass match {
         case c: Class[_] if c == classOf[java.lang.String] => StringType
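
The rest of inferSchema (not shown in this hunk) maps obj.getClass for each value of the first row to a Spark SQL type. The standalone sketch below is not Spark code; it uses stand-in DataType objects to illustrate the getClass-based mapping and why nested values are hard to support this way, which is roughly the limitation the removed comment calls out (Java objects don't carry classTags).

// Standalone sketch of getClass-based schema inference; the DataType values are
// stand-ins, not Spark's real catalyst types.
sealed trait DataType
case object StringType  extends DataType
case object IntegerType extends DataType
case object DoubleType  extends DataType
case object BooleanType extends DataType

object InferSchemaSketch {
  def inferField(obj: Any): DataType = obj.getClass match {
    case c if c == classOf[java.lang.String]  => StringType
    case c if c == classOf[java.lang.Integer] => IntegerType
    case c if c == classOf[java.lang.Double]  => DoubleType
    case c if c == classOf[java.lang.Boolean] => BooleanType
    // A nested value such as a java.util.HashMap only exposes its erased class
    // here, which is why nested types are left as a TODO in the real method.
    case other => sys.error(s"Unsupported type: $other")
  }

  def main(args: Array[String]): Unit = {
    val firstRow = Map[String, Any]("field1" -> 1, "field2" -> "row1")
    val schema = firstRow.map { case (name, value) => name -> inferField(value) }
    println(schema)  // Map(field1 -> IntegerType, field2 -> StringType)
  }
}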

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 5 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import net.razorvine.pickle.{Pickler, Unpickler}
+import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
 import org.apache.spark.annotation.{AlphaComponent, Experimental}
@@ -313,12 +313,15 @@ class SchemaRDD(
   /** FOR INTERNAL USE ONLY */
   def analyze = sqlContext.analyzer(logicalPlan)
 
-  def javaToPython: JavaRDD[Array[Byte]] = {
+  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
         val map: JMap[String, Any] = new java.util.HashMap
+        // TODO: We place the map in an ArrayList so that the object is pickled to a List[Dict].
+        // Ideally we should be able to pickle an object directly into a Python collection so we
+        // don't have to create an ArrayList every time.
         val arr: java.util.ArrayList[Any] = new java.util.ArrayList
         row.zip(fieldNames).foreach { case (obj, name) =>
           map.put(name, obj)
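
Outside of Spark, the per-row work inside javaToPython amounts to the sketch below: zip the row's values with the analyzed field names into a java.util.HashMap, wrap that map in an ArrayList so it pickles to a List[Dict] (the workaround the new TODO comment describes), and hand it to Pyrolite's Pickler. The field names and values here are invented for illustration.

import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

import net.razorvine.pickle.Pickler

object RowPickleSketch {
  def main(args: Array[String]): Unit = {
    val fieldNames = Seq("field1", "field2")  // taken from queryExecution.analyzed.output in the real code
    val row        = Seq[Any](1, "row1")      // one row's values, invented for this sketch

    // Build {name -> value} for the row.
    val map = new JHashMap[String, Any]()
    row.zip(fieldNames).foreach { case (obj, name) => map.put(name, obj) }

    // Wrap the map in an ArrayList so it unpickles as [{'field1': 1, 'field2': 'row1'}]
    // on the Python side (the List[Dict] workaround noted in the TODO).
    val arr = new JArrayList[Any]()
    arr.add(map)

    val bytes: Array[Byte] = new Pickler().dumps(arr)
    println(s"pickled ${bytes.length} bytes")
  }
}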

Comments (0)