@@ -1122,7 +1122,7 @@ def applySchema(self, rdd, schema):
11221122 batched = isinstance (rdd ._jrdd_deserializer , BatchedSerializer )
11231123 jrdd = self ._pythonToJava (rdd ._jrdd , batched )
11241124 srdd = self ._ssql_ctx .applySchemaToPythonRDD (jrdd .rdd (), str (schema ))
1125- return SchemaRDD (srdd , self )
1125+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
11261126
11271127 def registerRDDAsTable (self , rdd , tableName ):
11281128 """Registers the given RDD as a temporary table in the catalog.
@@ -1134,8 +1134,8 @@ def registerRDDAsTable(self, rdd, tableName):
11341134 >>> sqlCtx.registerRDDAsTable(srdd, "table1")
11351135 """
11361136 if (rdd .__class__ is SchemaRDD ):
1137- jschema_rdd = rdd ._jschema_rdd
1138- self ._ssql_ctx .registerRDDAsTable (jschema_rdd , tableName )
1137+ srdd = rdd ._jschema_rdd . baseSchemaRDD ()
1138+ self ._ssql_ctx .registerRDDAsTable (srdd , tableName )
11391139 else :
11401140 raise ValueError ("Can only register SchemaRDD as table" )
11411141
@@ -1151,7 +1151,7 @@ def parquetFile(self, path):
11511151 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
11521152 True
11531153 """
1154- jschema_rdd = self ._ssql_ctx .parquetFile (path )
1154+ jschema_rdd = self ._ssql_ctx .parquetFile (path ). toJavaSchemaRDD ()
11551155 return SchemaRDD (jschema_rdd , self )
11561156
11571157 def jsonFile (self , path , schema = None ):
@@ -1207,11 +1207,11 @@ def jsonFile(self, path, schema=None):
12071207 [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
12081208 """
12091209 if schema is None :
1210- jschema_rdd = self ._ssql_ctx .jsonFile (path )
1210+ srdd = self ._ssql_ctx .jsonFile (path )
12111211 else :
12121212 scala_datatype = self ._ssql_ctx .parseDataType (str (schema ))
1213- jschema_rdd = self ._ssql_ctx .jsonFile (path , scala_datatype )
1214- return SchemaRDD (jschema_rdd , self )
1213+ srdd = self ._ssql_ctx .jsonFile (path , scala_datatype )
1214+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
12151215
12161216 def jsonRDD (self , rdd , schema = None ):
12171217 """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1275,11 +1275,11 @@ def func(iterator):
12751275 keyed ._bypass_serializer = True
12761276 jrdd = keyed ._jrdd .map (self ._jvm .BytesToString ())
12771277 if schema is None :
1278- jschema_rdd = self ._ssql_ctx .jsonRDD (jrdd .rdd ())
1278+ srdd = self ._ssql_ctx .jsonRDD (jrdd .rdd ())
12791279 else :
12801280 scala_datatype = self ._ssql_ctx .parseDataType (str (schema ))
1281- jschema_rdd = self ._ssql_ctx .jsonRDD (jrdd .rdd (), scala_datatype )
1282- return SchemaRDD (jschema_rdd , self )
1281+ srdd = self ._ssql_ctx .jsonRDD (jrdd .rdd (), scala_datatype )
1282+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
12831283
12841284 def sql (self , sqlQuery ):
12851285 """Return a L{SchemaRDD} representing the result of the given query.
@@ -1290,7 +1290,7 @@ def sql(self, sqlQuery):
12901290 >>> srdd2.collect()
12911291 [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
12921292 """
1293- return SchemaRDD (self ._ssql_ctx .sql (sqlQuery ), self )
1293+ return SchemaRDD (self ._ssql_ctx .sql (sqlQuery ). toJavaSchemaRDD () , self )
12941294
12951295 def table (self , tableName ):
12961296 """Returns the specified table as a L{SchemaRDD}.
@@ -1301,7 +1301,7 @@ def table(self, tableName):
13011301 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
13021302 True
13031303 """
1304- return SchemaRDD (self ._ssql_ctx .table (tableName ), self )
1304+ return SchemaRDD (self ._ssql_ctx .table (tableName ). toJavaSchemaRDD () , self )
13051305
13061306 def cacheTable (self , tableName ):
13071307 """Caches the specified table in-memory."""
@@ -1353,7 +1353,7 @@ def hiveql(self, hqlQuery):
13531353 warnings .warn ("hiveql() is deprecated as the sql function now parses using HiveQL by" +
13541354 "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'" ,
13551355 DeprecationWarning )
1356- return SchemaRDD (self ._ssql_ctx .hiveql (hqlQuery ), self )
1356+ return SchemaRDD (self ._ssql_ctx .hiveql (hqlQuery ). toJavaSchemaRDD () , self )
13571357
13581358 def hql (self , hqlQuery ):
13591359 """
@@ -1524,6 +1524,8 @@ class SchemaRDD(RDD):
15241524 def __init__ (self , jschema_rdd , sql_ctx ):
15251525 self .sql_ctx = sql_ctx
15261526 self ._sc = sql_ctx ._sc
1527+ clsName = jschema_rdd .getClass ().getName ()
1528+ assert clsName .endswith ("JavaSchemaRDD" ), "jschema_rdd must be JavaSchemaRDD"
15271529 self ._jschema_rdd = jschema_rdd
15281530 self ._id = None
15291531 self .is_cached = False
@@ -1540,7 +1542,7 @@ def _jrdd(self):
15401542 L{pyspark.rdd.RDD} super class (map, filter, etc.).
15411543 """
15421544 if not hasattr (self , '_lazy_jrdd' ):
1543- self ._lazy_jrdd = self ._jschema_rdd .javaToPython ()
1545+ self ._lazy_jrdd = self ._jschema_rdd .baseSchemaRDD (). javaToPython ()
15441546 return self ._lazy_jrdd
15451547
15461548 def id (self ):
@@ -1598,7 +1600,7 @@ def saveAsTable(self, tableName):
15981600 def schema (self ):
15991601 """Returns the schema of this SchemaRDD (represented by
16001602 a L{StructType})."""
1601- return _parse_datatype_string (self ._jschema_rdd .schema ().toString ())
1603+ return _parse_datatype_string (self ._jschema_rdd .baseSchemaRDD (). schema ().toString ())
16021604
16031605 def schemaString (self ):
16041606 """Returns the output schema in the tree format."""
@@ -1649,8 +1651,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
16491651 rdd = RDD (self ._jrdd , self ._sc , self ._jrdd_deserializer )
16501652
16511653 schema = self .schema ()
1652- import pickle
1653- pickle .loads (pickle .dumps (schema ))
16541654
16551655 def applySchema (_ , it ):
16561656 cls = _create_cls (schema )
@@ -1687,10 +1687,8 @@ def isCheckpointed(self):
16871687
16881688 def getCheckpointFile (self ):
16891689 checkpointFile = self ._jschema_rdd .getCheckpointFile ()
1690- if checkpointFile .isDefined ():
1690+ if checkpointFile .isPresent ():
16911691 return checkpointFile .get ()
1692- else :
1693- return None
16941692
16951693 def coalesce (self , numPartitions , shuffle = False ):
16961694 rdd = self ._jschema_rdd .coalesce (numPartitions , shuffle )
0 commit comments