From 8f0a0bfdb044acf185548b214e904b00d3ba3be3 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 9 May 2016 22:09:35 +0800
Subject: [PATCH 1/7] null check for SparkSession.createDataFrame

---
 .../apache/spark/sql/catalyst/ScalaReflection.scala |  4 ++--
 .../spark/sql/catalyst/encoders/RowEncoder.scala    | 12 +++++-------
 .../sql/catalyst/expressions/BoundAttribute.scala   |  2 +-
 .../spark/sql/catalyst/expressions/objects.scala    |  4 +++-
 .../scala/org/apache/spark/sql/types/Decimal.scala  |  1 +
 .../scala/org/apache/spark/sql/SparkSession.scala   |  4 ++--
 .../scala/org/apache/spark/sql/DatasetSuite.scala   | 13 +++++++++++--
 .../org/apache/spark/sql/test/SQLTestUtils.scala    |  6 +-----
 8 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d158a64a85bc..d942d611ccdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -113,8 +113,8 @@ object ScalaReflection extends ScalaReflection {
    * Returns true if the value of this data type is same between internal and external.
    */
   def isNativeType(dt: DataType): Boolean = dt match {
-    case BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType => true
+    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
+         FloatType | DoubleType | BinaryType | CalendarIntervalType => true
     case _ => false
   }
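[Context note, not part of the patch: "native" here means the external Row representation and the internal InternalRow representation coincide, so the encoder may pass such values through unchanged. A minimal sketch of the distinction, assuming the Spark 2.0-era type objects:

    import org.apache.spark.sql.types._

    // Pass-through types after this change: the value a Row carries is already
    // what InternalRow stores (e.g. a Long for LongType, null for NullType).
    val native = Seq(NullType, BooleanType, LongType, BinaryType, CalendarIntervalType)
    // StringType is NOT native: Row carries java.lang.String while InternalRow
    // stores UTF8String, so a conversion expression is still required.
]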
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index cfde3bfbecbd..a44bbff40cb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -49,8 +49,7 @@ object RowEncoder {
   private def serializerFor(
       inputObject: Expression,
       inputType: DataType): Expression = inputType match {
-    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+    case dt if ScalaReflection.isNativeType(dt) => inputObject
 
     case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
@@ -84,10 +83,10 @@ object RowEncoder {
         "fromJavaDate",
         inputObject :: Nil)
 
-    case _: DecimalType =>
+    case d: DecimalType =>
       StaticInvoke(
         Decimal.getClass,
-        DecimalType.SYSTEM_DEFAULT,
+        d,
         "fromDecimal",
         inputObject :: Nil)
 
@@ -130,7 +129,7 @@ object RowEncoder {
     case StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
         val fieldValue = serializerFor(
-          GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
+          GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)),
           f.dataType
         )
         if (f.nullable) {
@@ -199,8 +198,7 @@ object RowEncoder {
   }
 
   private def deserializerFor(input: Expression): Expression = input.dataType match {
-    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType | CalendarIntervalType => input
+    case dt if ScalaReflection.isNativeType(dt) => input
 
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 99f156a935b5..a38f1ec09156 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression {
 
-  override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
+  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"
 
   // Use special getter for primitive types (for UnsafeRow)
   override def eval(input: InternalRow): Any = {

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index dbaff1625ed5..ba92a3c4d733 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -699,6 +699,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
+    fieldName: String,
     dataType: DataType) extends UnaryExpression with NonSQLExpression {
 
   override def nullable: Boolean = false
@@ -722,7 +723,8 @@ case class GetExternalRowField(
       }
 
       if (${row.value}.isNullAt($index)) {
-        throw new RuntimeException("The ${index}th field of input row cannot be null.");
+        throw new RuntimeException("The ${index}th field '$fieldName' of input row " +
+          "cannot be null.");
       }
 
       final ${ctx.javaType(dataType)} ${ev.value} = $getField;

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 6f4ec6b70191..2f7422b7420d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -386,6 +386,7 @@ object Decimal {
   def fromDecimal(value: Any): Decimal = {
     value match {
       case j: java.math.BigDecimal => apply(j)
+      case d: BigDecimal => apply(d)
       case d: Decimal => d
     }
   }
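[Context note, not part of the patch: with the added `case d: BigDecimal`, Decimal.fromDecimal covers every decimal representation a Row may legally carry. A hedged sketch of the behavior difference, calling the internal API directly:

    import org.apache.spark.sql.types.Decimal

    Decimal.fromDecimal(new java.math.BigDecimal("1.23")) // already handled
    Decimal.fromDecimal(BigDecimal("1.23"))               // MatchError before this patch
    Decimal.fromDecimal(Decimal(123L, 5, 2))              // already handled
]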
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 2a893c6478d1..0027606441d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -490,8 +490,8 @@ class SparkSession private(
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
     val catalystRows = if (needsConversion) {
-      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
-      rowRDD.map(converter(_).asInstanceOf[InternalRow])
+      val encoder = RowEncoder(schema)
+      rowRDD.map(encoder.toRow)
     } else {
       rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
     }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3cb4e52c6d41..301319ebe82b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -505,7 +505,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val schema = StructType(Seq(
       StructField("f", StructType(Seq(
         StructField("a", StringType, nullable = true),
-        StructField("b", IntegerType, nullable = false)
+        StructField("b", IntegerType, nullable = true)
       )), nullable = true)
     ))
 
@@ -672,7 +672,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val message = intercept[Exception] {
       df.collect()
     }.getMessage
-    assert(message.contains("The 0th field of input row cannot be null"))
+    assert(message.contains("The 0th field 'i' of input row cannot be null"))
+  }
+
+  test("row nullability mismatch") {
+    val schema = new StructType().add("a", StringType, true).add("b", StringType, false)
+    val rdd = sqlContext.sparkContext.parallelize(Row(null, "123") :: Row("234", null) :: Nil)
+    val message = intercept[Exception] {
+      sqlContext.createDataFrame(rdd, schema).collect()
+    }.getMessage
+    assert(message.contains("The 1th field 'b' of input row cannot be null"))
   }
 }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 6d2b95e83a44..e286ddeca524 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -217,11 +217,7 @@ private[sql] trait SQLTestUtils
       case FilterExec(_, child) => child
     }
 
-    val childRDD = withoutFilters
-      .execute()
-      .map(row => Row.fromSeq(row.copy().toSeq(schema)))
-
-    sqlContext.createDataFrame(childRDD, schema)
+    sqlContext.internalCreateDataFrame(withoutFilters.execute(), schema)
   }
 
   /**
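[Context note, not part of the patch: the user-visible effect of routing createDataFrame through RowEncoder is that a null in a non-nullable column now fails with a field-named error instead of slipping through the old converter path. A sketch mirroring the new test above, assuming a live sqlContext:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("a", StringType, nullable = true)
      .add("b", StringType, nullable = false)
    val rdd = sqlContext.sparkContext.parallelize(Seq(Row(null, "123"), Row("234", null)))

    // The second row puts null into the non-nullable column 'b'; collecting now
    // throws: "The 1th field 'b' of input row cannot be null."
    sqlContext.createDataFrame(rdd, schema).collect()
]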
From 7419a525f3471b7a225721436150be78ee94e22b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 12 May 2016 10:50:30 +0800
Subject: [PATCH 2/7] minor cleanup

---
 .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index fa2273668b44..b3a55198eda4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -169,7 +169,6 @@ object RowEncoder {
 
   private def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
-    case CalendarIntervalType => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -178,7 +177,6 @@ object RowEncoder {
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
     case udt: UserDefinedType[_] => ObjectType(udt.userClass)
-    case _: NullType => ObjectType(classOf[java.lang.Object])
   }
 
   private def deserializerFor(schema: StructType): Expression = {

From 0915a71e2834617db28dc13f1a68c11c71c35009 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 12 May 2016 13:15:14 +0800
Subject: [PATCH 3/7] fix mllib

---
 mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala   | 2 +-
 .../src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 9166faa54de5..28e4966f918a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -116,7 +116,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
         StructField("freq", LongType))
       val schema = StructType(fields)
       val rowDataRDD = model.freqItemsets.map { x =>
-        Row(x.items, x.freq)
+        Row(x.items.toSeq, x.freq)
       }
       sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 4344ab1bade9..41a9141ae4ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -633,7 +633,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
         StructField("freq", LongType))
       val schema = StructType(fields)
       val rowDataRDD = model.freqSequences.map { x =>
-        Row(x.sequence, x.freq)
+        Row(x.sequence.toSeq.map(_.toSeq), x.freq)
       }
       sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }
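[Context note, not part of the patch: RowEncoder is stricter than the old CatalystTypeConverters path about external types — an ArrayType column must be fed a scala Seq, not a raw Array, which is what the .toSeq calls above fix. A sketch with hypothetical values:

    import org.apache.spark.sql.Row

    val items: Array[String] = Array("a", "b")
    Row(items)       // rejected by RowEncoder for an ArrayType column
    Row(items.toSeq) // accepted: Seq is the expected external representation
]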
From 3acf24f1ed256d4f2411b5c67a688587cba589ab Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 12 May 2016 15:53:22 +0800
Subject: [PATCH 4/7] fix R

---
 .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 36173a49250b..6fdb4e4afc4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -110,6 +110,12 @@ private[sql] object SQLUtils {
     data match {
       case d: java.lang.Double if dataType == FloatType =>
         new java.lang.Float(d)
+      case a: Array[Boolean] if dataType.isInstanceOf[ArrayType] => a.toSeq
+      case a: Array[Int] if dataType.isInstanceOf[ArrayType] => a.toSeq
+      case a: Array[Double] if dataType.isInstanceOf[ArrayType] => a.toSeq
+      case a if a.getClass.isArray && dataType.isInstanceOf[ArrayType] =>
+        a.asInstanceOf[Array[Object]].toSeq
+          .map(doConversion(_, dataType.asInstanceOf[ArrayType].elementType))
       case _ => data
     }
   }
@@ -120,14 +126,14 @@ private[sql] object SQLUtils {
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
       doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
-    }.toSeq)
+    })
   }
 
   private[sql] def rowToRBytes(row: Row): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val dos = new DataOutputStream(bos)
 
-    val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
+    val cols = row.toSeq.toArray
     SerDe.writeObject(dos, cols)
     bos.toByteArray()
   }
@@ -198,7 +204,7 @@ private[sql] object SQLUtils {
     dataType match {
       case 's' =>
         // Read StructType for DataFrame
-        val fields = SerDe.readList(dis).asInstanceOf[Array[Object]]
+        val fields = SerDe.readList(dis)
         Row.fromSeq(fields)
       case _ => null
     }

From 225128d48c9a9c913339a6291947675509991317 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 17 May 2016 17:06:41 +0800
Subject: [PATCH 5/7] rebase

---
 .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala |  2 +-
 .../org/apache/spark/mllib/fpm/PrefixSpan.scala     |  2 +-
 .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 12 +++---------
 3 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 28e4966f918a..9166faa54de5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -116,7 +116,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
         StructField("freq", LongType))
       val schema = StructType(fields)
       val rowDataRDD = model.freqItemsets.map { x =>
-        Row(x.items.toSeq, x.freq)
+        Row(x.items, x.freq)
       }
       sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 41a9141ae4ef..4344ab1bade9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -633,7 +633,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
         StructField("freq", LongType))
       val schema = StructType(fields)
       val rowDataRDD = model.freqSequences.map { x =>
-        Row(x.sequence.toSeq.map(_.toSeq), x.freq)
+        Row(x.sequence, x.freq)
       }
       sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 5df5b1c8542b..ffb606f2c66d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -108,12 +108,6 @@ private[sql] object SQLUtils {
     data match {
       case d: java.lang.Double if dataType == FloatType =>
         new java.lang.Float(d)
-      case a: Array[Boolean] if dataType.isInstanceOf[ArrayType] => a.toSeq
-      case a: Array[Int] if dataType.isInstanceOf[ArrayType] => a.toSeq
-      case a: Array[Double] if dataType.isInstanceOf[ArrayType] => a.toSeq
-      case a if a.getClass.isArray && dataType.isInstanceOf[ArrayType] =>
-        a.asInstanceOf[Array[Object]].toSeq
-          .map(doConversion(_, dataType.asInstanceOf[ArrayType].elementType))
       case _ => data
     }
   }
@@ -124,14 +118,14 @@ private[sql] object SQLUtils {
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
       doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
-    })
+    }.toSeq)
   }
 
   private[sql] def rowToRBytes(row: Row): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val dos = new DataOutputStream(bos)
 
-    val cols = row.toSeq.toArray
+    val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
     SerDe.writeObject(dos, cols)
     bos.toByteArray()
   }
@@ -202,7 +196,7 @@ private[sql] object SQLUtils {
     dataType match {
       case 's' =>
         // Read StructType for DataFrame
-        val fields = SerDe.readList(dis)
+        val fields = SerDe.readList(dis).asInstanceOf[Array[Object]]
         Row.fromSeq(fields)
       case _ => null
     }
From 57efddbf737686c536933def1937011575caec24 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 18 May 2016 12:04:35 +0800
Subject: [PATCH 6/7] fix ml

---
 mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 9166faa54de5..28e4966f918a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -116,7 +116,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
         StructField("freq", LongType))
       val schema = StructType(fields)
       val rowDataRDD = model.freqItemsets.map { x =>
-        Row(x.items, x.freq)
+        Row(x.items.toSeq, x.freq)
       }
       sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }

From f533188c8c7bb295307ce141d0ba1f74788bd21c Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 18 May 2016 18:06:58 +0800
Subject: [PATCH 7/7] fix R

---
 .../src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index ffb606f2c66d..486a440b6f9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.api.r
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
+import scala.collection.JavaConverters._
 import scala.util.matching.Regex
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
@@ -108,6 +109,8 @@ private[sql] object SQLUtils {
     data match {
       case d: java.lang.Double if dataType == FloatType =>
         new java.lang.Float(d)
+      // Scala Map is the only allowed external type of map type in Row.
+      case m: java.util.Map[_, _] => m.asScala
       case _ => data
     }
   }
@@ -118,7 +121,7 @@ private[sql] object SQLUtils {
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
       doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
-    }.toSeq)
+    })
   }
 
   private[sql] def rowToRBytes(row: Row): Array[Byte] = {
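[Context note, not part of the patch: the same strictness applies to maps — as the added comment says, scala.collection.Map is the only external type RowEncoder accepts for MapType, hence the asScala conversion for values deserialized from R. A sketch with a hypothetical java.util.Map value:

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.Row

    val jmap = new java.util.HashMap[String, Integer]()
    jmap.put("a", 1)
    Row(jmap)         // rejected by RowEncoder for a MapType column
    Row(jmap.asScala) // accepted: scala.collection.Map is the expected external type
]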