 
 import sys
 import types
-import array
 import itertools
 import warnings
 import decimal
 import datetime
-from operator import itemgetter
 import keyword
 import warnings
+from array import array
+from operator import itemgetter
 
 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer
@@ -441,7 +441,7 @@ def _infer_type(obj):
             raise ValueError("Can not infer type for empty dict")
         key, value = obj.iteritems().next()
         return MapType(_infer_type(key), _infer_type(value), True)
-    elif isinstance(obj, (list, array.array)):
+    elif isinstance(obj, (list, array)):
         if not obj:
             raise ValueError("Can not infer type for empty list/array")
         return ArrayType(_infer_type(obj[0]), True)
@@ -456,14 +456,20 @@ def _infer_schema(row):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
+
     elif isinstance(row, tuple):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
-        elif all(isinstance(x, tuple) and len(x) == 2
-                 for x in row):
+        elif hasattr(row, "__FIELDS__"):  # Row
+            items = zip(row.__FIELDS__, tuple(row))
+        elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
             items = row
+        else:
+            raise ValueError("Can't infer schema from tuple")
+
     elif hasattr(row, "__dict__"):  # object
         items = sorted(row.__dict__.items())
+
     else:
         raise ValueError("Can not infer schema for type: %s" % type(row))
 
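
As a reading aid (not part of the diff), here is a minimal standalone sketch of the field-extraction order this hunk introduces for tuples; the helper name _items_from_tuple is hypothetical and the real logic lives inline in _infer_schema:

def _items_from_tuple(row):
    # namedtuple: field names come from _fields
    if hasattr(row, "_fields"):
        return zip(row._fields, tuple(row))
    # Row built from kwargs: field names come from __FIELDS__
    if hasattr(row, "__FIELDS__"):
        return zip(row.__FIELDS__, tuple(row))
    # a plain tuple of (name, value) pairs is accepted as-is
    if all(isinstance(x, tuple) and len(x) == 2 for x in row):
        return list(row)
    raise ValueError("Can't infer schema from tuple")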
@@ -494,9 +500,12 @@ def _create_converter(obj, dataType):
     elif isinstance(obj, tuple):
         if hasattr(obj, "_fields"):  # namedtuple
             conv = tuple
-        elif all(isinstance(x, tuple) and len(x) == 2
-                 for x in obj):
+        elif hasattr(obj, "__FIELDS__"):
+            conv = tuple
+        elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
             conv = lambda o: tuple(v for k, v in o)
+        else:
+            raise ValueError("unexpected tuple")
 
     elif hasattr(obj, "__dict__"):  # object
         conv = lambda o: [o.__dict__.get(n, None) for n in names]
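
Likewise, a hedged sketch of the three converters the hunk above selects for tuple-like rows (the variable names below are illustrative; only conv exists in the real function):

conv_named = tuple                                   # namedtuple, or Row with __FIELDS__
conv_pairs = lambda o: tuple(v for k, v in o)        # list of (key, value) pairs: keep values only
names = ["f1", "f2"]                                 # assumed field names for the plain-object case
conv_obj = lambda o: [o.__dict__.get(n, None) for n in names]

conv_pairs([("f1", 1), ("f2", 2.0)])                 # -> (1, 2.0)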
@@ -783,6 +792,7 @@ class Row(tuple):
         """ Row in SchemaRDD """
         __DATATYPE__ = dataType
         __FIELDS__ = tuple(f.name for f in dataType.fields)
+        __slots__ = ()
 
         # create property for fast access
         locals().update(_create_properties(dataType.fields))
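
Side note on the added __slots__ = (): on a tuple subclass it suppresses the per-instance __dict__, presumably to keep the generated Row objects compact. A tiny plain-Python illustration (the classes below are hypothetical, not the generated class):

class Plain(tuple):
    pass

class Slotted(tuple):
    __slots__ = ()

p = Plain((1, 2))
p.extra = "ok"            # works: Plain instances carry a __dict__
s = Slotted((1, 2))
try:
    s.extra = "no"        # AttributeError: no per-instance __dict__
except AttributeError:
    pass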
@@ -814,7 +824,7 @@ def __init__(self, sparkContext, sqlContext=None):
         >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
             ...
-        ValueError:...
+        TypeError:...
 
         >>> bad_rdd = sc.parallelize([1,2,3])
         >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
@@ -823,9 +833,9 @@ def __init__(self, sparkContext, sqlContext=None):
         ValueError:...
 
         >>> from datetime import datetime
-        >>> allTypes = sc.parallelize([{"int": 1, "string": "string",
-        ...     "double": 1.0, "long": 1L, "boolean": True, "list": [1, 2, 3],
-        ...     "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},}])
+        >>> allTypes = sc.parallelize([Row(int=1, string="string",
+        ...     double=1.0, long=1L, boolean=True, list=[1, 2, 3],
+        ...     time=datetime(2010, 1, 1, 1, 1, 1), dict={"a": 1})])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string,
         ...     x.double, x.long, x.boolean, x.time, x.dict["a"], x.list))
         >>> srdd.collect()[0]
@@ -851,33 +861,48 @@ def _ssql_ctx(self):
         return self._scala_SQLContext
 
     def inferSchema(self, rdd):
-        """Infer and apply a schema to an RDD of L{dict}s.
+        """Infer and apply a schema to an RDD of L{Row}s.
+
+        We peek at the first row of the RDD to determine the fields' names
+        and types. Nested collections are supported, including array, dict,
+        list, Row, tuple, namedtuple, and object.
 
-        We peek at the first row of the RDD to determine the fields names
-        and types, and then use that to extract all the dictionaries. Nested
-        collections are supported, which include array, dict, list, set, and
-        tuple.
+        Each row in `rdd` should be a Row object, a namedtuple, or an
+        object; using a dict is deprecated.
 
+        >>> rdd = sc.parallelize(
+        ...     [Row(field1=1, field2="row1"),
+        ...      Row(field1=2, field2="row2"),
+        ...      Row(field1=3, field2="row3")])
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect()[0]
         Row(field1=1, field2=u'row1')
 
-        >>> from array import array
+        >>> NestedRow = Row("f1", "f2")
+        >>> nestedRdd1 = sc.parallelize([
+        ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}),
+        ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})])
         >>> srdd = sqlCtx.inferSchema(nestedRdd1)
         >>> srdd.collect()
         [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
 
+        >>> nestedRdd2 = sc.parallelize([
+        ...     NestedRow([[1, 2], [2, 3]], [1, 2]),
+        ...     NestedRow([[2, 3], [3, 4]], [2, 3])])
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
         >>> srdd.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
         """
-        if (rdd.__class__ is SchemaRDD):
-            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
+
+        if isinstance(rdd, SchemaRDD):
+            raise TypeError("Cannot apply schema to SchemaRDD")
 
         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated")
 
         schema = _infer_schema(first)
         rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
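
A hedged usage note on the new dict handling (not part of the diff; it reuses the sc/sqlCtx doctest globals): an RDD of plain dicts still infers, but now goes through the warnings.warn call added above, and Row or namedtuple input is the preferred path.

old_style = sc.parallelize([{"field1": 1, "field2": "row1"}])
srdd = sqlCtx.inferSchema(old_style)  # warns: "Using RDD of dict to inferSchema is deprecated"
srdd.collect()[0]                     # still usable: Row(field1=1, field2=u'row1')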
@@ -889,6 +914,7 @@ def applySchema(self, rdd, schema):
 
         The schema should be a StructType.
 
+        >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...     StructField("field2", StringType(), False)])
         >>> srdd = sqlCtx.applySchema(rdd2, schema)
@@ -929,6 +955,9 @@ def applySchema(self, rdd, schema):
         [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
         """
 
+        if isinstance(rdd, SchemaRDD):
+            raise TypeError("Cannot apply schema to SchemaRDD")
+
         first = rdd.first()
         if not isinstance(first, (tuple, list)):
             raise ValueError("Can not apply schema to type: %s" % type(first))
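
Likewise, a small sketch (outside the diff, reusing the rdd2/schema doctest names) of the new guard: applying a schema to something that already is a SchemaRDD is rejected up front.

srdd = sqlCtx.applySchema(rdd2, schema)
try:
    sqlCtx.applySchema(srdd, schema)   # raises TypeError("Cannot apply schema to SchemaRDD")
except TypeError:
    pass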
@@ -1198,12 +1227,84 @@ def _get_hive_ctx(self):
         return self._jvm.TestHiveContext(self._jsc.sc())
 
 
-# a stub type, the real type is dynamic generated.
+def _create_row(fields, values):
+    row = Row(*values)
+    row.__FIELDS__ = fields
+    return row
+
+
 class Row(tuple):
     """
     A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+
+    Row can be used to create a row object by using named arguments;
+    the fields will be sorted by name.
+
+    >>> row = Row(name="Alice", age=11)
+    >>> row
+    Row(age=11, name='Alice')
+    >>> row.name, row.age
+    ('Alice', 11)
+
+    Row can also be used to create another Row-like class, which can then
+    be used to create Row objects, such as
+
+    >>> Person = Row("name", "age")
+    >>> Person
+    <Row(name, age)>
+    >>> Person("Alice", 11)
+    Row(name='Alice', age=11)
     """
 
+    def __new__(self, *args, **kwargs):
+        if args and kwargs:
+            raise ValueError("Can not use both args "
+                             "and kwargs to create Row")
+        if args:
+            # create row class or objects
+            return tuple.__new__(self, args)
+
+        elif kwargs:
+            # create row objects
+            names = sorted(kwargs.keys())
+            values = tuple(kwargs[n] for n in names)
+            row = tuple.__new__(self, values)
+            row.__FIELDS__ = names
+            return row
+
+        else:
+            raise ValueError("No args or kwargs")
+
+    # let a Row object act like a class
+    def __call__(self, *args):
+        """create new Row object"""
+        return _create_row(self, args)
+
+    def __getattr__(self, item):
+        if item.startswith("__"):
+            raise AttributeError(item)
+        try:
+            # it will be slow when it has many fields,
+            # but this will not be used in normal cases
+            idx = self.__FIELDS__.index(item)
+            return self[idx]
+        except ValueError:
+            # list.index() raises ValueError for a missing field name
+            raise AttributeError(item)
+
+    def __reduce__(self):
+        if hasattr(self, "__FIELDS__"):
+            return (_create_row, (self.__FIELDS__, tuple(self)))
+        else:
+            return tuple.__reduce__(self)
+
+    def __repr__(self):
+        if hasattr(self, "__FIELDS__"):
+            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+                                         for k, v in zip(self.__FIELDS__, self))
+        else:
+            return "<Row(%s)>" % ", ".join(self)
+
 
 class SchemaRDD(RDD):
     """An RDD of L{Row} objects that has an associated schema.
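
For orientation (a sketch outside the diff), this is what the __reduce__ added above enables: kwargs-built Rows rebuild through _create_row, so they survive pickling with their field names intact.

import pickle
row = Row(age=11, name="Alice")             # kwargs form: fields sorted by name
restored = pickle.loads(pickle.dumps(row))  # __reduce__ -> (_create_row, (fields, values))
assert restored == row and restored.name == "Alice"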
@@ -1424,19 +1525,18 @@ def _test():
     from pyspark.context import SparkContext
     # let doctest run in pyspark.sql, so DataTypes can be picklable
     import pyspark.sql
-    from pyspark.sql import SQLContext
+    from pyspark.sql import Row, SQLContext
     globs = pyspark.sql.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize(
-        [{"field1": 1, "field2": "row1"},
-         {"field1": 2, "field2": "row2"},
-         {"field1": 3, "field2": "row3"}]
+        [Row(field1=1, field2="row1"),
+         Row(field1=2, field2="row2"),
+         Row(field1=3, field2="row3")]
     )
-    globs['rdd2'] = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
     jsonStrings = [
         '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
         '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
@@ -1446,12 +1546,6 @@ def _test():
     ]
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
-    globs['nestedRdd1'] = sc.parallelize([
-        {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
-        {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
-    globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
-        {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()