6 changes: 3 additions & 3 deletions python/pyspark/sql.py
@@ -186,15 +186,15 @@ class ArrayType(DataType):

"""

def __init__(self, elementType, containsNull=False):
def __init__(self, elementType, containsNull=True):
"""Creates an ArrayType

:param elementType: the data type of elements.
:param containsNull: indicates whether the list can contain None values.

>>> ArrayType(StringType) == ArrayType(StringType, False)
>>> ArrayType(StringType) == ArrayType(StringType, True)
True
>>> ArrayType(StringType, True) == ArrayType(StringType)
>>> ArrayType(StringType, False) == ArrayType(StringType)
False
"""
self.elementType = elementType
@@ -62,11 +62,14 @@ object ScalaReflection {
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
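For illustration, a minimal sketch of the schemas derived under the new rule, using a hypothetical Record case class and assuming the catalyst schemaFor and type imports of this branch (package names may differ in later releases):

    import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
    import org.apache.spark.sql.catalyst.types._

    case class Record(primitives: Seq[Int], boxed: Seq[java.lang.Integer], scores: Map[Int, java.lang.Long])

    // Element and value nullability now drives containsNull / valueContainsNull:
    //   primitives -> ArrayType(IntegerType, containsNull = false)
    //   boxed      -> ArrayType(IntegerType, containsNull = true)
    //   scores     -> MapType(IntegerType, LongType, valueContainsNull = true)
    val derived = schemaFor[Record].dataType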
@@ -270,8 +270,8 @@ case object FloatType extends FractionalType {
}

object ArrayType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, false)
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
}
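// A sketch of the new default: constructing an ArrayType from the element type alone
// now yields an array that may hold null elements, e.g.
//   ArrayType(StringType) == ArrayType(StringType, containsNull = true)    // true
//   ArrayType(StringType) == ArrayType(StringType, containsNull = false)   // false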

/**
@@ -57,7 +57,9 @@ case class OptionalData(

case class ComplexData(
arrayField: Seq[Int],
mapField: Map[Int, String],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
structField: PrimitiveData)

case class GenericData[A](
@@ -116,8 +118,22 @@ class ScalaReflectionSuite extends FunSuite {
val schema = schemaFor[ComplexData]
assert(schema === Schema(
StructType(Seq(
StructField("arrayField", ArrayType(IntegerType), nullable = true),
StructField("mapField", MapType(IntegerType, StringType), nullable = true),
StructField(
"arrayField",
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
nullable = true),
StructField(
"mapField",
MapType(IntegerType, LongType, valueContainsNull = false),
nullable = true),
StructField(
"mapFieldValueContainsNull",
MapType(IntegerType, LongType, valueContainsNull = true),
nullable = true),
StructField(
"structField",
StructType(Seq(
@@ -86,14 +86,14 @@ public abstract class DataType {

/**
* Creates an ArrayType by specifying the data type of elements ({@code elementType}).
* The field of {@code containsNull} is set to {@code false}.
* The field of {@code containsNull} is set to {@code true}.
*/
public static ArrayType createArrayType(DataType elementType) {
if (elementType == null) {
throw new IllegalArgumentException("elementType should not be null.");
}

return new ArrayType(elementType, false);
return new ArrayType(elementType, true);
}
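// A usage sketch, assuming the Java API mirrors the Scala default:
//   DataType.createArrayType(elementType) now produces an array type whose
//   elements may be null, i.e. the same as new ArrayType(elementType, true).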

/**
@@ -206,7 +206,8 @@ case class Sort(
object ExistingRdd {
def convertToCatalyst(a: Any): Any = a match {
case o: Option[_] => o.orNull
case s: Seq[Any] => s.map(convertToCatalyst)
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
case other => other
}
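For illustration, a rough sketch of how the extended matcher handles nested values, assuming ExistingRdd is reachable from the calling code:

    // Options are unwrapped; Seq elements and Map keys/values are converted recursively.
    ExistingRdd.convertToCatalyst(Map("a" -> Seq(Some(1), None)))
    // expected result: Map("a" -> Seq(1, null))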
@@ -58,6 +58,7 @@ private[sql] object CatalystConverter {
// This is mostly Parquet convention (see, e.g., `ConversionPatterns`).
// Note that "array" for the array elements is chosen by ParquetAvro.
// Using a different value will result in Parquet silently dropping columns.
val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag"
val ARRAY_ELEMENTS_SCHEMA_NAME = "array"
val MAP_KEY_SCHEMA_NAME = "key"
val MAP_VALUE_SCHEMA_NAME = "value"
@@ -82,6 +83,9 @@
case ArrayType(elementType: DataType, false) => {
new CatalystArrayConverter(elementType, fieldIndex, parent)
}
case ArrayType(elementType: DataType, true) => {
new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent)
}
case StructType(fields: Seq[StructField]) => {
new CatalystStructConverter(fields.toArray, fieldIndex, parent)
}
@@ -567,6 +571,85 @@ private[parquet] class CatalystNativeArrayConverter(
}
}

/**
* A `parquet.io.api.GroupConverter` that converts single-element groups that
* match the characteristics of an array with nullable elements (see
* [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an
* [[org.apache.spark.sql.catalyst.types.ArrayType]].
*
* @param elementType The type of the array elements (complex or primitive)
* @param index The position of this (array) field inside its parent converter
* @param parent The parent converter
* @param buffer A data buffer
*/
private[parquet] class CatalystArrayContainsNullConverter(
val elementType: DataType,
val index: Int,
protected[parquet] val parent: CatalystConverter,
protected[parquet] var buffer: Buffer[Any])
extends CatalystConverter {

def this(elementType: DataType, index: Int, parent: CatalystConverter) =
this(
elementType,
index,
parent,
new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE))

protected[parquet] val converter: Converter = new CatalystConverter {

private var current: Any = null

val converter = CatalystConverter.createConverter(
new CatalystConverter.FieldType(
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
elementType,
false),
fieldIndex = 0,
parent = this)

override def getConverter(fieldIndex: Int): Converter = converter

override def end(): Unit = parent.updateField(index, current)

override def start(): Unit = {
current = null
}

override protected[parquet] val size: Int = 1
override protected[parquet] val index: Int = 0
override protected[parquet] val parent = CatalystArrayContainsNullConverter.this

override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = {
current = value
}

override protected[parquet] def clearBuffer(): Unit = {}
}

override def getConverter(fieldIndex: Int): Converter = converter

// arrays have only one (repeated) field, which is its elements
override val size = 1

override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = {
buffer += value
}

override protected[parquet] def clearBuffer(): Unit = {
buffer.clear()
}

override def start(): Unit = {}

override def end(): Unit = {
assert(parent != null)
// here we need to make sure to use ArrayScalaType
parent.updateField(index, buffer.toArray.toSeq)
clearBuffer()
}
}
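// A rough read-side sketch, assuming the repeated "bag" layout written by
// RowWriteSupport below: each repeated entry drives one start()/end() cycle of the
// inner converter, whose end() hands the current element (a value, or null when the
// optional "array" field was absent) to updateField above, appending it to the buffer;
// the outer end() then passes buffer.toSeq, e.g. Seq(7, null, 42), up to the parent.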

/**
* This converter is for multi-element groups of primitive or complex types
* that have repetition level optional or required (so struct fields).
@@ -173,7 +173,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
if (value != null) {
schema match {
case t @ ArrayType(_, false) => writeArray(
case t @ ArrayType(_, _) => writeArray(
t,
value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
case t @ MapType(_, _, _) => writeMap(
@@ -228,45 +228,57 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
}
}

// TODO: support null values, see
// https://issues.apache.org/jira/browse/SPARK-1649
private[parquet] def writeArray(
schema: ArrayType,
array: CatalystConverter.ArrayScalaType[_]): Unit = {
val elementType = schema.elementType
writer.startGroup()
if (array.size > 0) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
while(i < array.size) {
writeValue(elementType, array(i))
i = i + 1
if (schema.containsNull) {
writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0)
var i = 0
while (i < array.size) {
writer.startGroup()
if (array(i) != null) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
writeValue(elementType, array(i))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0)
} else {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
while (i < array.size) {
writeValue(elementType, array(i))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
}
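// Write-side sketch: for ArrayType(StringType, containsNull = true) holding
// Seq("a", null), the branch above emits one repeated "bag" group per element;
// the entry for the null element simply omits the optional "array" field,
// while non-null elements write their value into "array".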

// TODO: support null values, see
// https://issues.apache.org/jira/browse/SPARK-1649
private[parquet] def writeMap(
schema: MapType,
map: CatalystConverter.MapScalaType[_, _]): Unit = {
writer.startGroup()
if (map.size > 0) {
writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
writer.startGroup()
writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
for(key <- map.keys) {
for ((key, value) <- map) {
writer.startGroup()
writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
writeValue(schema.keyType, key)
writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
if (value != null) {
writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
writeValue(schema.valueType, value)
writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
}
writer.endGroup()
}
writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
for(value <- map.values) {
writeValue(schema.valueType, value)
}
writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
writer.endGroup()
writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0)
}
writer.endGroup()
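// Map entries are now written one group per (key, value) pair; a null value
// simply omits the optional "value" field, so null map values round-trip
// instead of relying on the old key-list/value-list layout removed above.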
@@ -119,7 +119,13 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetOriginalType.LIST => { // TODO: check enums!
assert(groupType.getFieldCount == 1)
val field = groupType.getFields.apply(0)
ArrayType(toDataType(field, isBinaryAsString), containsNull = false)
if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) {
val bag = field.asGroupType()
assert(bag.getFieldCount == 1)
ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true)
} else {
ArrayType(toDataType(field, isBinaryAsString), containsNull = false)
}
}
case ParquetOriginalType.MAP => {
assert(
@@ -129,28 +135,32 @@
assert(
keyValueGroup.getFieldCount == 2,
"Parquet Map type malformatted: nested group should have 2 (key, value) fields!")
val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)

val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString)
assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
// TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true
// at here.
MapType(keyType, valueType)
MapType(keyType, valueType,
keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED)
}
case _ => {
// Note: the order of these checks is important!
if (correspondsToMap(groupType)) { // MapType
val keyValueGroup = groupType.getFields.apply(0).asGroupType()
val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)

val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString)
assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
// TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true
// at here.
MapType(keyType, valueType)
MapType(keyType, valueType,
keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED)
} else if (correspondsToArray(groupType)) { // ArrayType
val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString)
ArrayType(elementType, containsNull = false)
val field = groupType.getFields.apply(0)
if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) {
val bag = field.asGroupType()
assert(bag.getFieldCount == 1)
ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true)
} else {
ArrayType(toDataType(field, isBinaryAsString), containsNull = false)
}
} else { // everything else: StructType
val fields = groupType
.getFields
@@ -249,13 +259,27 @@ private[parquet] object ParquetTypesConverter extends Logging {
inArray = true)
ConversionPatterns.listType(repetition, name, parquetElementType)
}
case ArrayType(elementType, true) => {
val parquetElementType = fromDataType(
elementType,
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
nullable = true,
inArray = false)
ConversionPatterns.listType(
repetition,
name,
new ParquetGroupType(
Repetition.REPEATED,
CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME,
parquetElementType))
}
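// Sketch of the Parquet type this case is expected to produce for a nullable field
// named "ints" of ArrayType(IntegerType, containsNull = true), using the naming
// constants above:
//   optional group ints (LIST) {
//     repeated group bag {
//       optional int32 array;
//     }
//   }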
case StructType(structFields) => {
val fields = structFields.map {
field => fromDataType(field.dataType, field.name, field.nullable, inArray = false)
}
new ParquetGroupType(repetition, name, fields)
}
case MapType(keyType, valueType, _) => {
case MapType(keyType, valueType, valueContainsNull) => {
val parquetKeyType =
fromDataType(
keyType,
@@ -266,7 +290,7 @@
fromDataType(
valueType,
CatalystConverter.MAP_VALUE_SCHEMA_NAME,
nullable = false,
nullable = valueContainsNull,
inArray = false)
ConversionPatterns.mapType(
repetition,