153 changes: 75 additions & 78 deletions python/pyspark/sql.py
@@ -24,6 +24,7 @@
import datetime
import keyword
import warnings
import json
from array import array
from operator import itemgetter

@@ -62,6 +63,18 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def typeName(cls):
return cls.__name__[:-4].lower()

def jsonValue(self):
return self.typeName()

def json(self):
return json.dumps(self.jsonValue(),
separators=(',', ':'),
sort_keys=True)
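
Note: typeName() derives the wire name by stripping the trailing "Type" from the class name and lowercasing it, so primitive types serialize as bare JSON strings. An illustrative doctest, not part of the diff (StringType is one of the primitive singletons defined below in this file):

>>> StringType.typeName()
'string'
>>> StringType().json()
'"string"'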


class PrimitiveTypeSingleton(type):

@@ -205,6 +218,16 @@ def __repr__(self):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())

def jsonValue(self):
return {"type": self.typeName(),
"elementType": self.elementType.jsonValue(),
"containsNull": self.containsNull}

@classmethod
def fromJson(cls, json):
return ArrayType(_parse_datatype_json_value(json["elementType"]),
json["containsNull"])


class MapType(DataType):

@@ -245,6 +268,18 @@ def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())

def jsonValue(self):
return {"type": self.typeName(),
"keyType": self.keyType.jsonValue(),
"valueType": self.valueType.jsonValue(),
"valueContainsNull": self.valueContainsNull}

@classmethod
def fromJson(cls, json):
return MapType(_parse_datatype_json_value(json["keyType"]),
_parse_datatype_json_value(json["valueType"]),
json["valueContainsNull"])


class StructField(DataType):

@@ -283,6 +318,17 @@ def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())

def jsonValue(self):
return {"name": self.name,
"type": self.dataType.jsonValue(),
"nullable": self.nullable}

@classmethod
def fromJson(cls, json):
return StructField(json["name"],
_parse_datatype_json_value(json["type"]),
json["nullable"])


class StructType(DataType):

@@ -312,42 +358,30 @@ def __repr__(self):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))

def jsonValue(self):
return {"type": self.typeName(),
"fields": [f.jsonValue() for f in self.fields]}

@classmethod
def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])


def _parse_datatype_list(datatype_list_string):
"""Parses a list of comma separated data types."""
index = 0
datatype_list = []
start = 0
depth = 0
while index < len(datatype_list_string):
if depth == 0 and datatype_list_string[index] == ",":
datatype_string = datatype_list_string[start:index].strip()
datatype_list.append(_parse_datatype_string(datatype_string))
start = index + 1
elif datatype_list_string[index] == "(":
depth += 1
elif datatype_list_string[index] == ")":
depth -= 1
index += 1

# Handle the last data type
datatype_string = datatype_list_string[start:index].strip()
datatype_list.append(_parse_datatype_string(datatype_string))
return datatype_list
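
With the new StructType.jsonValue()/fromJson() above, a one-field struct nests its fields' payloads under "fields". An illustrative doctest, not part of the diff:

>>> StructType([StructField("age", IntegerType(), True)]).json()
'{"fields":[{"name":"age","nullable":true,"type":"integer"}],"type":"struct"}'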
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
v.__base__ == PrimitiveType)


_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
_all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])
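
The new registries key each type by its typeName() rather than the class name, so parsing reduces to a plain dict lookup (illustrative sketch):

>>> sorted(_all_complex_types.keys())
['array', 'map', 'struct']
>>> 'integer' in _all_primitive_types
True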


def _parse_datatype_string(datatype_string):
"""Parses the given data type string.

def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
... python_datatype = _parse_datatype_string(
... scala_datatype.toString())
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -385,51 +419,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
index = datatype_string.find("(")
if index == -1:
# It is a primitive type.
index = len(datatype_string)
type_or_field = datatype_string[:index]
rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()

if type_or_field in _all_primitive_types:
return _all_primitive_types[type_or_field]()

elif type_or_field == "ArrayType":
last_comma_index = rest_part.rfind(",")
containsNull = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
containsNull = False
elementType = _parse_datatype_string(
rest_part[:last_comma_index].strip())
return ArrayType(elementType, containsNull)

elif type_or_field == "MapType":
last_comma_index = rest_part.rfind(",")
valueContainsNull = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
valueContainsNull = False
keyType, valueType = _parse_datatype_list(
rest_part[:last_comma_index].strip())
return MapType(keyType, valueType, valueContainsNull)

elif type_or_field == "StructField":
first_comma_index = rest_part.find(",")
name = rest_part[:first_comma_index].strip()
last_comma_index = rest_part.rfind(",")
nullable = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
nullable = False
dataType = _parse_datatype_string(
rest_part[first_comma_index + 1:last_comma_index].strip())
return StructField(name, dataType, nullable)

elif type_or_field == "StructType":
# rest_part should be in the format like
# List(StructField(field1,IntegerType,false)).
field_list_string = rest_part[rest_part.find("(") + 1:-1]
fields = _parse_datatype_list(field_list_string)
return StructType(fields)
return _parse_datatype_json_value(json.loads(json_string))


def _parse_datatype_json_value(json_value):
if type(json_value) is unicode and json_value in _all_primitive_types.keys():
return _all_primitive_types[json_value]()
else:
return _all_complex_types[json_value["type"]].fromJson(json_value)
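
Note that _parse_datatype_json_value expects values as produced by json.loads — unicode strings in Python 2 — which is why round trips go through the serialized string rather than raw jsonValue() dicts. An end-to-end sketch, not part of the diff:

>>> t = StructType([StructField("age", IntegerType(), True)])
>>> _parse_datatype_json_string(t.json()) == t
True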


# Mapping Python types to Spark SQL DateType
@@ -983,7 +980,7 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
str(returnType))
returnType.json())

def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1119,7 +1116,7 @@ def applySchema(self, rdd, schema):

batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def registerRDDAsTable(self, rdd, tableName):
@@ -1209,7 +1206,7 @@ def jsonFile(self, path, schema=None):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

@@ -1279,7 +1276,7 @@ def func(iterator):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

@@ -1614,7 +1611,7 @@ def saveAsTable(self, tableName):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())

def schemaString(self):
"""Returns the output schema in the tree format."""
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
case object DynamicType extends DataType {
def simpleString: String = "dynamic"
}
case object DynamicType extends DataType

/**
* Wrap a [[Row]] as a [[DynamicRow]].