153 changes: 75 additions & 78 deletions python/pyspark/sql.py
@@ -34,6 +34,7 @@
import datetime
import keyword
import warnings
import json
from array import array
from operator import itemgetter
from itertools import imap
@@ -71,6 +72,18 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def typeName(cls):
return cls.__name__[:-4].lower()

def jsonValue(self):
return self.typeName()

def json(self):
return json.dumps(self.jsonValue(),
separators=(',', ':'),
sort_keys=True)

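As a side note (illustrative, not part of the diff): a minimal sketch of what these new hooks yield for a primitive type, assuming this branch's `pyspark.sql` module is importable.

```python
# Minimal usage sketch (illustrative, not part of the diff): the new
# serialization hooks on DataType, assuming pyspark.sql from this branch.
from pyspark.sql import IntegerType

t = IntegerType()
print(t.typeName())   # "integer" -- class name minus the "Type" suffix, lowercased
print(t.jsonValue())  # "integer" -- primitive types serialize to a bare string
print(t.json())       # '"integer"' -- compact JSON produced by json.dumps
```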

class PrimitiveTypeSingleton(type):

@@ -214,6 +227,16 @@ def __repr__(self):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())

def jsonValue(self):
return {"type": self.typeName(),
"elementType": self.elementType.jsonValue(),
"containsNull": self.containsNull}

@classmethod
def fromJson(cls, json):
return ArrayType(_parse_datatype_json_value(json["elementType"]),
json["containsNull"])


class MapType(DataType):

@@ -254,6 +277,18 @@ def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())

def jsonValue(self):
return {"type": self.typeName(),
"keyType": self.keyType.jsonValue(),
"valueType": self.valueType.jsonValue(),
"valueContainsNull": self.valueContainsNull}

@classmethod
def fromJson(cls, json):
return MapType(_parse_datatype_json_value(json["keyType"]),
_parse_datatype_json_value(json["valueType"]),
json["valueContainsNull"])


class StructField(DataType):

@@ -292,6 +327,17 @@ def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())

def jsonValue(self):
return {"name": self.name,
"type": self.dataType.jsonValue(),
"nullable": self.nullable}

@classmethod
def fromJson(cls, json):
return StructField(json["name"],
_parse_datatype_json_value(json["type"]),
json["nullable"])


class StructType(DataType):

@@ -321,42 +367,30 @@ def __repr__(self):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))

def jsonValue(self):
return {"type": self.typeName(),
"fields": [f.jsonValue() for f in self.fields]}

def _parse_datatype_list(datatype_list_string):
"""Parses a list of comma separated data types."""
index = 0
datatype_list = []
start = 0
depth = 0
while index < len(datatype_list_string):
if depth == 0 and datatype_list_string[index] == ",":
datatype_string = datatype_list_string[start:index].strip()
datatype_list.append(_parse_datatype_string(datatype_string))
start = index + 1
elif datatype_list_string[index] == "(":
depth += 1
elif datatype_list_string[index] == ")":
depth -= 1
@classmethod
def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])

index += 1

# Handle the last data type
datatype_string = datatype_list_string[start:index].strip()
datatype_list.append(_parse_datatype_string(datatype_string))
return datatype_list
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
v.__base__ == PrimitiveType)


_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
_all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])


def _parse_datatype_string(datatype_string):
"""Parses the given data type string.

def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
... python_datatype = _parse_datatype_string(
... scala_datatype.toString())
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
index = datatype_string.find("(")
if index == -1:
# It is a primitive type.
index = len(datatype_string)
type_or_field = datatype_string[:index]
rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()

if type_or_field in _all_primitive_types:
return _all_primitive_types[type_or_field]()

elif type_or_field == "ArrayType":
last_comma_index = rest_part.rfind(",")
containsNull = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
containsNull = False
elementType = _parse_datatype_string(
rest_part[:last_comma_index].strip())
return ArrayType(elementType, containsNull)

elif type_or_field == "MapType":
last_comma_index = rest_part.rfind(",")
valueContainsNull = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
valueContainsNull = False
keyType, valueType = _parse_datatype_list(
rest_part[:last_comma_index].strip())
return MapType(keyType, valueType, valueContainsNull)

elif type_or_field == "StructField":
first_comma_index = rest_part.find(",")
name = rest_part[:first_comma_index].strip()
last_comma_index = rest_part.rfind(",")
nullable = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
nullable = False
dataType = _parse_datatype_string(
rest_part[first_comma_index + 1:last_comma_index].strip())
return StructField(name, dataType, nullable)

elif type_or_field == "StructType":
# rest_part should be in the format like
# List(StructField(field1,IntegerType,false)).
field_list_string = rest_part[rest_part.find("(") + 1:-1]
fields = _parse_datatype_list(field_list_string)
return StructType(fields)
return _parse_datatype_json_value(json.loads(json_string))


def _parse_datatype_json_value(json_value):
if type(json_value) is unicode and json_value in _all_primitive_types.keys():
return _all_primitive_types[json_value]()
else:
return _all_complex_types[json_value["type"]].fromJson(json_value)

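For context (illustrative, not part of the diff): the JSON round trip that the doctest above exercises through the JVM can also be sketched purely on the Python side, assuming this module's names are in scope.

```python
# Hypothetical round-trip sketch: serialize a schema to JSON with the new
# json() method and parse it back with _parse_datatype_json_string.
from pyspark.sql import (ArrayType, IntegerType, StringType, StructField,
                         StructType, _parse_datatype_json_string)

schema = StructType([
    StructField("name", StringType(), False),
    StructField("scores", ArrayType(IntegerType(), True), True)])

json_string = schema.json()                         # compact JSON string
restored = _parse_datatype_json_string(json_string)
assert restored == schema                           # should round-trip losslessly
```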

# Mapping Python types to Spark SQL DateType
@@ -992,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
str(returnType))
returnType.json())

def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1128,7 +1125,7 @@ def applySchema(self, rdd, schema):

batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def registerRDDAsTable(self, rdd, tableName):
@@ -1218,7 +1215,7 @@ def jsonFile(self, path, schema=None):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

@@ -1288,7 +1285,7 @@ def func(iterator):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

@@ -1623,7 +1620,7 @@ def saveAsTable(self, tableName):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())

def schemaString(self):
"""Returns the output schema in the tree format."""
31 changes: 15 additions & 16 deletions sql/README.md
@@ -44,38 +44,37 @@ Type in expressions to have them evaluated.
Type :help for more information.

scala> val query = sql("SELECT * FROM (SELECT * FROM src) a")
query: org.apache.spark.sql.ExecutedQuery =
SELECT * FROM (SELECT * FROM src) a
=== Query Plan ===
Project [key#6:0.0,value#7:0.1]
HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None
query: org.apache.spark.sql.SchemaRDD =
== Query Plan ==
== Physical Plan ==
HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None
```

Query results are RDDs and can be operated on as such.
```
scala> query.collect()
res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]...
res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]...
```

You can also build further queries on top of these RDDs using the query DSL.
```
scala> query.where('key === 100).toRdd.collect()
res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100])
scala> query.where('key === 100).collect()
res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100])
```

From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects.
From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects.
```scala
scala> query.logicalPlan
res1: catalyst.plans.logical.LogicalPlan =
Project {key#0,value#1}
Project {key#0,value#1}
scala> query.queryExecution.analyzed
res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
Project [key#10,value#11]
Project [key#10,value#11]
MetastoreRelation default, src, None


scala> query.logicalPlan transform {
scala> query.queryExecution.analyzed transform {
| case Project(projectList, child) if projectList == child.output => child
| }
res2: catalyst.plans.logical.LogicalPlan =
Project {key#0,value#1}
res5: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
Project [key#10,value#11]
MetastoreRelation default, src, None
```
@@ -348,8 +348,11 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e

// Decimal and Double remain the same
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType == DecimalType => d
case d: Divide if d.resolved && d.dataType == DoubleType => d
case d: Divide if d.resolved && d.dataType == DecimalType => d

case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
case object DynamicType extends DataType {
def simpleString: String = "dynamic"
}
case object DynamicType extends DataType

/**
* Wrap a [[Row]] as a [[DynamicRow]].
@@ -299,6 +299,18 @@ object BooleanSimplification extends Rule[LogicalPlan] {
case (_, _) => or
}

case not @ Not(exp) =>
exp match {
case Literal(true, BooleanType) => Literal(false)
case Literal(false, BooleanType) => Literal(true)
case GreaterThan(l, r) => LessThanOrEqual(l, r)
case GreaterThanOrEqual(l, r) => LessThan(l, r)
case LessThan(l, r) => GreaterThanOrEqual(l, r)
case LessThanOrEqual(l, r) => GreaterThan(l, r)
case Not(e) => e
case _ => not
}

// Turn "if (true) a else b" into "a", and if (false) a else b" into "b".
case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue
}