Commit e7a6c19

SchemaRDD.javaToPython should convert a field with the StructType to a Map.
1 parent 6d20b85 commit e7a6c19
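
In short, javaToPython previously copied each top-level column straight into a Java HashMap keyed by field name, so a column with a nested StructType was handed to the pickler as a raw Row. With this commit the conversion recurses into struct-typed fields and turns them into Java Maps, which unpickle on the Python side as plain dicts. A minimal PySpark sketch of the intended behavior, mirroring the doctests updated below (assumes an existing SparkContext `sc` and SQLContext `sqlCtx`):

    # Nested JSON record, as in the updated doctests in python/pyspark/sql.py.
    json = sc.parallelize(['{"field1": 1, "field2": "row1", "field3": {"field4": 11}}'])
    srdd = sqlCtx.jsonRDD(json)
    sqlCtx.registerRDDAsTable(srdd, "table1")
    srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 AS f2, field3 AS f3 FROM table1")
    # The struct-typed column now arrives in Python as a plain dict:
    print srdd2.collect()  # [{"f1": 1, "f2": "row1", "f3": {"field4": 11}}]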

3 files changed: +48, -17 lines changed

docs/sql-programming-guide.md

Lines changed: 11 additions & 4 deletions
@@ -238,6 +238,8 @@ teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
 
 # The results of SQL queries are RDDs and support all the normal RDD operations.
 teenNames = teenagers.map(lambda p: "Name: " + p.name)
+for teenName in teenNames.collect():
+  print teenName
 {% endhighlight %}
 
 </div>
@@ -275,7 +277,7 @@ val parquetFile = sqlCtx.parquetFile("people.parquet")
 //Parquet files can also be registered as tables and then used in SQL statements.
 parquetFile.registerAsTable("parquetFile")
 val teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
-teenagers.collect().foreach(println)
+teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
 {% endhighlight %}
 
 </div>
@@ -311,10 +313,10 @@ List<String> teenagerNames = teenagers.map(new Function<Row, String>() {
 {% highlight python %}
 # sqlCtx from the previous example is used in this example.
 
-peopleTable # The SchemaRDD from the previous example.
+schemaPeople # The SchemaRDD from the previous example.
 
 # SchemaRDDs can be saved as Parquet files, maintaining the schema information.
-peopleTable.saveAsParquetFile("people.parquet")
+schemaPeople.saveAsParquetFile("people.parquet")
 
 # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
 # The result of loading a parquet file is also a SchemaRDD.
@@ -324,6 +326,8 @@ parquetFile = sqlCtx.parquetFile("people.parquet")
 parquetFile.registerAsTable("parquetFile");
 teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
 teenNames = teenagers.map(lambda p: "Name: " + p.name)
+for teenName in teenNames.collect():
+  print teenName
 {% endhighlight %}
 
 </div>
@@ -477,11 +481,13 @@ people.printSchema()
 people.registerAsTable("people")
 
 # SQL statements can be run by using the sql methods provided by sqlCtx.
-val teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
+teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
 
 # The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
 # The columns of a row in the result can be accessed by ordinal.
 teenNames = teenagers.map(lambda p: "Name: " + p.name)
+for teenName in teenNames.collect():
+  print teenName
 
 # Alternatively, a SchemaRDD can be created for a JSON dataset represented by
 # a RDD[String] storing one JSON object per string.
@@ -597,6 +603,7 @@ val people: RDD[Person] = ... // An RDD of case class objects, from the first ex
 
 // The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19'
 val teenagers = people.where('age >= 10).where('age <= 19).select('name)
+teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
 {% endhighlight %}
 
 The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers

python/pyspark/sql.py

Lines changed: 12 additions & 9 deletions
@@ -138,9 +138,10 @@ def jsonFile(self, path):
         >>> ofn.close()
         >>> srdd = sqlCtx.jsonFile(jsonFile)
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
-        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
-        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
-        ... {"f1" : 3, "f2": "row3"}]
+        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1")
+        >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}},
+        ... {"f1": 2, "f2": "row2", "f3":{"field4":22}},
+        ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}]
         True
         """
         jschema_rdd = self._ssql_ctx.jsonFile(path)
@@ -151,9 +152,10 @@ def jsonRDD(self, rdd):
 
         >>> srdd = sqlCtx.jsonRDD(json)
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
-        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
-        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
-        ... {"f1" : 3, "f2": "row3"}]
+        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1")
+        >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}},
+        ... {"f1": 2, "f2": "row2", "f3":{"field4":22}},
+        ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}]
         True
         """
         def func(split, iterator):
@@ -369,7 +371,7 @@ def saveAsTable(self, tableName):
 
     def getSchemaTreeString(self):
         """Returns the output schema in the tree format."""
-        self._jschema_rdd.getSchemaTreeString()
+        return self._jschema_rdd.getSchemaTreeString()
 
     def printSchema(self):
         """Prints out the schema in the tree format."""
@@ -473,8 +475,9 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-    jsonStrings = ['{"field1": 1, "field2": "row1"}',
-        '{"field1" : 2, "field2": "row2"}', '{"field1" : 3, "field2": "row3"}']
+    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
+        '{"field1" : 2, "field2": "row2", "field3":{"field4":22}}',
+        '{"field1" : 3, "field2": "row3", "field3":{"field4":33}}']
     globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
     (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
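
One behavioral note from the getSchemaTreeString hunk above: the method previously called through to the JVM but dropped the result, so it always returned None; it now returns the tree-format schema string, while printSchema (per its docstring) keeps printing it. A short usage sketch, reusing the nested-JSON srdd from the example near the top of this page:

    # Assumes the srdd built from the nested JSON example earlier.
    tree = srdd.getSchemaTreeString()  # now a string instead of None
    print tree
    srdd.printSchema()                 # prints the same tree to stdout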

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 25 additions & 4 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.BooleanType
+import org.apache.spark.sql.catalyst.types.{DataType, StructType, BooleanType}
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.api.java.JavaRDD
 import java.util.{Map => JMap}
@@ -344,13 +344,34 @@ class SchemaRDD(
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
-    val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
+    def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
+      val fields = structType.fields.map(field => (field.name, field.dataType))
+      val map: JMap[String, Any] = new java.util.HashMap
+      row.zip(fields).foreach {
+        case (obj, (name, dataType)) =>
+          dataType match {
+            case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct))
+            case other => map.put(name, obj)
+          }
+      }
+
+      map
+    }
+
+    // TODO: Actually, the schema of a row should be represented by a StructType instead of
+    // a Seq[Attribute]. Once we have finished that change, we can just use rowToMap to
+    // construct the Map for python.
+    val fields: Seq[(String, DataType)] = this.queryExecution.analyzed.output.map(
+      field => (field.name, field.dataType))
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
         val map: JMap[String, Any] = new java.util.HashMap
-        row.zip(fieldNames).foreach { case (obj, name) =>
-          map.put(name, obj)
+        row.zip(fields).foreach { case (obj, (name, dataType)) =>
+          dataType match {
+            case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct))
+            case other => map.put(name, obj)
+          }
         }
         map
       }.grouped(10).map(batched => pickle.dumps(batched.toArray))