@@ -116,7 +116,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
        StructField("freq", LongType))
      val schema = StructType(fields)
      val rowDataRDD = model.freqItemsets.map { x =>
-        Row(x.items, x.freq)
+        Row(x.items.toSeq, x.freq)
> Contributor: Do we need to call toSeq here?
>
> Author: We do. This is a special case: FPGrowthModel has a type parameter and we use FPGrowthModel[_] here, so x.items returns Object[] instead of the T[] we expect and doesn't match the schema.

      }
      sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
    }
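The Object[] behaviour described above can be reproduced without Spark: an existentially-typed holder can hand back an Object[] even when every element is a String, and converting it to a Seq keeps the elements while dropping the dependence on the runtime array class, which is what the ArrayType column needs once createDataFrame goes through RowEncoder. A minimal standalone sketch with made-up names (Model, buildUntyped):

// Standalone illustration (not Spark code): an existentially-typed holder can
// surface an Object[] even when every element is a String.
class Model[T](val items: Array[T])

object ToSeqSketch {
  // Built without a ClassTag for the element type, so toArray yields Array[Any],
  // i.e. a runtime Object[].
  def buildUntyped(xs: Any*): Model[_] = new Model(xs.toArray)

  def main(args: Array[String]): Unit = {
    val m: Model[_] = buildUntyped("a", "b", "c")
    println(m.items.getClass.getSimpleName) // Object[]
    println(m.items.toSeq)                  // a Seq carrying the same String elements
  }
}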

@@ -113,8 +113,8 @@ object ScalaReflection extends ScalaReflection {
   * Returns true if the value of this data type is same between internal and external.
   */
  def isNativeType(dt: DataType): Boolean = dt match {
-    case BooleanType | ByteType | ShortType | IntegerType | LongType |
-        FloatType | DoubleType | BinaryType => true
+    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
+        FloatType | DoubleType | BinaryType | CalendarIntervalType => true
> Contributor: Why CalendarIntervalType?
>
> Author: Because we don't have an external representation for it.

    case _ => false
  }
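To make the answer above concrete: a "native" type here is one whose internal Catalyst value and external Row value are the same JVM object, so encoders can pass it through untouched; NullType and CalendarIntervalType now qualify because neither has a separate external class. A small sketch (the helper name isPassThrough is mine) mirroring the updated match:

import org.apache.spark.sql.types._

object NativeTypeSketch {
  // Mirrors the updated isNativeType: these types need no internal/external conversion.
  def isPassThrough(dt: DataType): Boolean = dt match {
    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | BinaryType | CalendarIntervalType => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(isPassThrough(CalendarIntervalType)) // true: no external class to convert to
    println(isPassThrough(StringType))           // false: String <-> UTF8String
    println(isPassThrough(TimestampType))        // false: java.sql.Timestamp <-> Long micros
  }
}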

@@ -70,8 +70,7 @@ object RowEncoder {
  private def serializerFor(
      inputObject: Expression,
      inputType: DataType): Expression = inputType match {
-    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
-      FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+    case dt if ScalaReflection.isNativeType(dt) => inputObject

    case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)

@@ -151,7 +150,7 @@
    case StructType(fields) =>
      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
        val fieldValue = serializerFor(
-          GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
+          GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)),
          f.dataType
        )
        if (f.nullable) {
@@ -193,7 +192,6 @@

  private def externalDataTypeFor(dt: DataType): DataType = dt match {
    case _ if ScalaReflection.isNativeType(dt) => dt
-    case CalendarIntervalType => dt
    case TimestampType => ObjectType(classOf[java.sql.Timestamp])
    case DateType => ObjectType(classOf[java.sql.Date])
    case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -202,7 +200,6 @@
    case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
    case _: StructType => ObjectType(classOf[Row])
    case udt: UserDefinedType[_] => ObjectType(udt.userClass)
-    case _: NullType => ObjectType(classOf[java.lang.Object])
  }

  private def deserializerFor(schema: StructType): Expression = {
@@ -222,8 +219,7 @@
  }

  private def deserializerFor(input: Expression): Expression = input.dataType match {
-    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
-      FloatType | DoubleType | BinaryType | CalendarIntervalType => input
+    case dt if ScalaReflection.isNativeType(dt) => input

    case udt: UserDefinedType[_] =>
      val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])

@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

-  override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
+  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {

@@ -693,6 +693,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
case class GetExternalRowField(
    child: Expression,
    index: Int,
+    fieldName: String,
    dataType: DataType) extends UnaryExpression with NonSQLExpression {

  override def nullable: Boolean = false
@@ -716,7 +717,8 @@
      }

      if (${row.value}.isNullAt($index)) {
-        throw new RuntimeException("The ${index}th field of input row cannot be null.");
+        throw new RuntimeException("The ${index}th field '$fieldName' of input row " +
+          "cannot be null.");
      }

      final ${ctx.javaType(dataType)} ${ev.value} = $getField;

@@ -477,8 +477,8 @@ class SparkSession private(
    // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
    // schema differs from the existing schema on any field data type.
    val catalystRows = if (needsConversion) {
-      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
-      rowRDD.map(converter(_).asInstanceOf[InternalRow])
+      val encoder = RowEncoder(schema)
> Contributor: So we already do the null check in RowEncoder, right?
>
> Author: Yes.

+      rowRDD.map(encoder.toRow)
    } else {
      rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
    }
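The user-visible effect of the switch above: createDataFrame now goes through RowEncoder, so a null value in a field the schema declares non-nullable is rejected during conversion, with the field name in the message (via GetExternalRowField). A rough usage sketch, assuming a local SparkSession; the object and column names are made up, and the exact message text is whatever the encoder produces:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object NullCheckSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("null-check-sketch").getOrCreate()

    // Column "b" is declared non-nullable, but the second row carries a null for it.
    val schema = new StructType()
      .add("a", StringType, nullable = true)
      .add("b", StringType, nullable = false)
    val rdd = spark.sparkContext.parallelize(Seq(Row("x", "1"), Row("y", null)))

    try {
      spark.createDataFrame(rdd, schema).collect()
    } catch {
      // Expect a message naming the offending field 'b' of the input row.
      case e: Exception => println(e.getMessage)
    }
    spark.stop()
  }
}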

@@ -19,6 +19,7 @@ package org.apache.spark.sql.api.r

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

+import scala.collection.JavaConverters._
import scala.util.matching.Regex

import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
@@ -108,6 +109,8 @@ private[sql] object SQLUtils {
    data match {
      case d: java.lang.Double if dataType == FloatType =>
        new java.lang.Float(d)
+      // Scala Map is the only allowed external type of map type in Row.
+      case m: java.util.Map[_, _] => m.asScala
      case _ => data
    }
  }
@@ -118,7 +121,7 @@
    val num = SerDe.readInt(dis)
    Row.fromSeq((0 until num).map { i =>
      doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
-    }.toSeq)
+    })
  }

  private[sql] def rowToRBytes(row: Row): Array[Byte] = {
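The comment added in doConversion states the constraint driving this change: for a MapType column, the value placed in an external Row has to be a scala.collection.Map, so a java.util.Map deserialized from the R side is converted with asScala first. A small standalone sketch of that constraint; the column name "m" and the object name are made up:

import scala.collection.JavaConverters._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object MapExternalTypeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("map-sketch").getOrCreate()
    val schema = new StructType().add("m", MapType(StringType, IntegerType))

    val javaMap = new java.util.HashMap[String, Integer]()
    javaMap.put("a", 1)

    // .asScala turns the java.util.Map into the scala.collection.Map that Row expects
    // for a MapType column; handing over javaMap directly would not match the external type.
    val rows = spark.sparkContext.parallelize(Seq(Row(javaMap.asScala)))
    spark.createDataFrame(rows, schema).show()
    spark.stop()
  }
}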

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala (13 changes: 11 additions & 2 deletions)

@@ -507,7 +507,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
    val schema = StructType(Seq(
      StructField("f", StructType(Seq(
        StructField("a", StringType, nullable = true),
-        StructField("b", IntegerType, nullable = false)
+        StructField("b", IntegerType, nullable = true)
> Contributor: Why change this?
>
> Author: With the new null check, we trigger the error earlier than this test expects. This test exercises the AssertNotNull expression, which is used when converting a nullable column into a non-nullable object field (like a primitive Int).
>
> Contributor: OK, so the new test ("row nullability mismatch") effectively covers this case? Then should we rename this test? Will we hit the exception checked by this test in any other cases? (Just want to make sure we are not losing test coverage.)
>
> Author: Yes. "row nullability mismatch" checks the error raised when we pass in a null value for a column that is declared non-nullable.
>
> Contributor: Thanks!
      )), nullable = true)
    ))
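For contrast with the discussion above: the existing test (with the schema just shown) exercises AssertNotNull, which fires when a value from a nullable column is decoded into a field that cannot hold null, such as a primitive Int, while the new GetExternalRowField check only fires when the column itself is declared non-nullable. A rough sketch of the AssertNotNull path, with invented names (IntHolder, AssertNotNullSketch):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

// A primitive Int field cannot hold null, so decoding a null triggers AssertNotNull.
case class IntHolder(b: Int)

object AssertNotNullSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("assert-not-null-sketch").getOrCreate()
    import spark.implicits._

    // The column is declared nullable, so createDataFrame's own null check passes.
    val schema = new StructType().add("b", IntegerType, nullable = true)
    val rdd = spark.sparkContext.parallelize(Seq(Row.fromSeq(Seq(null))))
    val ds = spark.createDataFrame(rdd, schema).as[IntHolder]

    try {
      ds.collect() // decoding null into IntHolder.b fails here, via AssertNotNull
    } catch {
      case e: Exception => println(e.getMessage)
    }
    spark.stop()
  }
}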

@@ -684,7 +684,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
    val message = intercept[Exception] {
      df.collect()
    }.getMessage
-    assert(message.contains("The 0th field of input row cannot be null"))
+    assert(message.contains("The 0th field 'i' of input row cannot be null"))
  }

+  test("row nullability mismatch") {
+    val schema = new StructType().add("a", StringType, true).add("b", StringType, false)
+    val rdd = sqlContext.sparkContext.parallelize(Row(null, "123") :: Row("234", null) :: Nil)
+    val message = intercept[Exception] {
+      sqlContext.createDataFrame(rdd, schema).collect()
+    }.getMessage
+    assert(message.contains("The 1th field 'b' of input row cannot be null"))
+  }
+
  test("createTempView") {

@@ -217,11 +217,7 @@ private[sql] trait SQLTestUtils
      case FilterExec(_, child) => child
    }

-    val childRDD = withoutFilters
-      .execute()
-      .map(row => Row.fromSeq(row.copy().toSeq(schema)))
-
-    spark.createDataFrame(childRDD, schema)
+    spark.internalCreateDataFrame(withoutFilters.execute(), schema)
  }

  /**