@@ -288,7 +288,14 @@ object JavaTypeInference {
     val setters = properties.map { p =>
       val fieldName = p.getName
       val fieldType = typeToken.method(p.getReadMethod).getReturnType
-      p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+      val (_, nullable) = inferDataType(fieldType)
+      val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+      val setter = if (nullable) {
+        constructor
+      } else {
+        AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+      }
+      p.getWriteMethod.getName -> setter
     }.toMap

     val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
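Note: the rule applied above is small enough to model on its own. The following is an illustrative sketch, not Spark source — `Expr`, `Ctor`, and `Assert` are toy stand-ins for Catalyst's `Expression`, the per-field deserializer, and the new `AssertNotNull`:

```scala
// Toy model of the setter-wrapping rule: nullable bean properties keep their
// deserializer expression as-is; non-nullable ones are wrapped so that a null
// input fails fast at runtime.
sealed trait Expr
case class Ctor(field: String) extends Expr
case class Assert(child: Expr, owner: String, field: String, tpe: String) extends Expr

def wrapSetter(ctor: Expr, nullable: Boolean, owner: String, field: String, tpe: String): Expr =
  if (nullable) ctor else Assert(ctor, owner, field, tpe)

// A primitive `int` property is inferred non-nullable (nullable = false),
// so it gets wrapped:
assert(wrapSetter(Ctor("b"), false, "SmallBean", "b", "int")
  == Assert(Ctor("b"), "SmallBean", "b", "int"))
```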
@@ -326,7 +326,7 @@ object ScalaReflection extends ScalaReflection {
       val cls = getClassFromType(tpe)

       val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
-        val dataType = schemaFor(fieldType).dataType
+        val Schema(dataType, nullable) = schemaFor(fieldType)
         val clsName = getClassNameFromType(fieldType)
         val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
         // For tuples, we grab the inner fields by ordinal instead of by name.
@@ -336,10 +336,16 @@ object ScalaReflection extends ScalaReflection {
             Some(addToPathOrdinal(i, dataType, newTypePath)),
             newTypePath)
         } else {
-          constructorFor(
+          val constructor = constructorFor(
             fieldType,
             Some(addToPath(fieldName, dataType, newTypePath)),
             newTypePath)
+
+          if (!nullable) {
+            AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+          } else {
+            constructor
+          }
         }
       }

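Note: the `nullable` flag destructured above comes from `schemaFor`, which infers primitive-backed Scala types as non-nullable and reference types as nullable. A hedged illustration with a hypothetical case class (not part of this patch):

```scala
// Hypothetical example of per-field nullability inference for a case class.
case class Person(name: String, age: Int, nickname: Option[String])
// name:     String          -> nullable     -> deserializer used as-is
// age:      Int (primitive) -> non-nullable -> wrapped in AssertNotNull(...)
// nickname: Option[String]  -> nullable     -> deserializer used as-is
```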
@@ -624,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
     """
   }
 }
+
+/**
+ * Asserts that input values of a non-nullable child expression are not null.
+ *
+ * Note that there are cases where `child.nullable == true` but we still need to add this
+ * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable
+ * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
+ * non-null `s`, `s.i` can't be null.
+ */
+case class AssertNotNull(
+    child: Expression, parentType: String, fieldName: String, fieldType: String)
+  extends UnaryExpression {
+
+  override def dataType: DataType = child.dataType
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val childGen = child.gen(ctx)
+
+    ev.isNull = "false"
+    ev.value = childGen.value
+
+    s"""
+      ${childGen.code}
+
+      if (${childGen.isNull}) {
+        throw new RuntimeException(
+          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
+          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+          "please try to use scala.Option[_] or other nullable types " +
+          "(e.g. java.lang.Integer instead of int/scala.Int)."
+        );
+      }
+    """
+  }
+}
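Note: since `eval` is unsupported and the behavior lives entirely in the generated Java, a plain-Scala model of the semantics may help. This is a sketch, not Spark source; an `Option` stands in for the child's (isNull, value) pair:

```scala
// Minimal model of AssertNotNull's runtime behavior: pass the child's value
// through unchanged, but fail fast with the descriptive error when the child
// evaluated to null.
def assertNotNull[T](child: Option[T], parent: String, field: String, tpe: String): T =
  child.getOrElse {
    throw new RuntimeException(
      s"Null value appeared in non-nullable field $parent.$field of type $tpe. " +
        "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
        "please try to use scala.Option[_] or other nullable types " +
        "(e.g. java.lang.Integer instead of int/scala.Int).")
  }

assertNotNull(Some(42), "Person", "age", "Int")   // => 42
// assertNotNull(None, "Person", "age", "Int")    // throws RuntimeException
```

Per the Scaladoc's `s.i` example, the struct deserializer short-circuits to null when `s` itself is null, so the assertion only fires for rows where `s` is present but `i` is null.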
@@ -36,12 +36,16 @@ class EncoderResolutionSuite extends PlanTest {
     val encoder = ExpressionEncoder[StringLongClass]
     val cls = classOf[StringLongClass]

+
     {
       val attrs = Seq('a.string, 'b.int)
       val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
       val expected: Expression = NewInstance(
         cls,
-        toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+        Seq(
+          toExternalString('a.string),
+          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
+        ),
         false,
         ObjectType(cls))
       compareExpressions(fromRowExpr, expected)
@@ -52,7 +56,10 @@
       val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
       val expected = NewInstance(
         cls,
-        toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+        Seq(
+          toExternalString('a.int.cast(StringType)),
+          AssertNotNull('b.long, cls.getName, "b", "Long")
+        ),
         false,
         ObjectType(cls))
       compareExpressions(fromRowExpr, expected)
@@ -69,7 +76,7 @@
       val expected: Expression = NewInstance(
         cls,
         Seq(
-          'a.int.cast(LongType),
+          AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
           If(
             'b.struct('a.int, 'b.long).isNull,
             Literal.create(null, ObjectType(innerCls)),
@@ -78,7 +85,9 @@
               Seq(
                 toExternalString(
                   GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
-                GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))),
+                AssertNotNull(
+                  GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
+                  innerCls.getName, "b", "Long")),
               false,
               ObjectType(innerCls))
           )),
@@ -102,7 +111,9 @@
         cls,
         Seq(
           toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
-          GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)),
+          AssertNotNull(
+            GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
+            cls.getName, "b", "Long")),
         false,
         ObjectType(cls)),
       'b.int.cast(LongType)),
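Note the pattern shared by the expected trees above: the assertion wraps the field's entire deserializer expression, including any cast inserted during resolution, so the null check applies to the final coerced value:

```scala
// Check wraps the cast result (shape used above):
//   AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
// rather than casting an already-checked value afterwards:
//   AssertNotNull('b.int, cls.getName, "b", "Long").cast(LongType)
```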
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import scala.collection.JavaConverters._

+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
 import org.apache.spark.rdd.RDD
@@ -64,7 +65,7 @@ import org.apache.spark.util.Utils
 class Dataset[T] private[sql](
     @transient override val sqlContext: SQLContext,
     @transient override val queryExecution: QueryExecution,
-    tEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {

   /**
    * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
@@ -23,6 +23,8 @@
 import java.sql.Timestamp;
 import java.util.*;

+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -39,7 +41,6 @@
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructType;

 import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public void testJavaBeanEncoder2() {
       context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
     ds.collect();
   }
+
+  public class SmallBean implements Serializable {
+    private String a;
+
+    private int b;
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public String getA() {
+      return a;
+    }
+
+    public void setA(String a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SmallBean smallBean = (SmallBean) o;
+      return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(a, b);
+    }
+  }
+
+  public class NestedSmallBean implements Serializable {
+    private SmallBean f;
+
+    public SmallBean getF() {
+      return f;
+    }
+
+    public void setF(SmallBean f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean that = (NestedSmallBean) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
+  @Rule
+  public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+  @Test
+  public void testRuntimeNullabilityCheck() {
+    OuterScopes.addOuterScope(this);
+
+    StructType schema = new StructType()
+      .add("f", new StructType()
+        .add("a", StringType, true)
+        .add("b", IntegerType, true), true);
+
+    // Shouldn't throw runtime exception since it passes nullability check.
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", 1
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      SmallBean smallBean = new SmallBean();
+      smallBean.setA("hello");
+      smallBean.setB(1);
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      nestedSmallBean.setF(smallBean);
+
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    // Shouldn't throw runtime exception when parent object (`SmallBean`) is null
+    {
+      Row row = new GenericRow(new Object[] { null });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    nullabilityCheck.expect(RuntimeException.class);
+    nullabilityCheck.expectMessage(
+      "Null value appeared in non-nullable field " +
+      "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", null
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      ds.collect();
+    }
+  }
 }
33 changes: 33 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps

 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


 class DatasetSuite extends QueryTest with SharedSQLContext {
@@ -515,12 +516,44 @@
     }
     assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
   }
+
+  test("runtime nullability check") {
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = true)
+    ))
+
+    def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+      val rowRDD = sqlContext.sparkContext.parallelize(rows)
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+    }
+
+    checkAnswer(
+      buildDataset(Row(Row("hello", 1))),
+      NestedStruct(ClassData("hello", 1))
+    )
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
[Review comment — Contributor]
My concern is: if the parent is null, we should shortcut the execution and return null
directly, instead of going into the field and triggering the null check. However, it looks
like we only do this shortcut for product types, via `If(IsNull...)`; we may also need to
handle array and map types.

[Reply — Contributor Author]
Good catch! Verified that the current mechanism doesn't play well with primitive arrays.

[Reply — Contributor Author]
Actually, the test case I constructed reflects another, separate bug, which has been fixed
in PR #10401. Discussed with @cloud-fan offline: what he meant was that we should also add
`AssertNotNull` for array and map types, which I totally agree with, but I think it would
be nice to add that in a separate PR.
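For context, the product-type shortcut mentioned in this thread is visible in the EncoderResolutionSuite trees above: the nested `NewInstance` is guarded by a null check on the parent struct, so a null parent never reaches the field-level assertion. A sketch of that shape, reconstructed from those expected trees:

```scala
// Reconstructed shape of the product-type null shortcut (illustrative only):
//   If('f.isNull,
//     Literal.create(null, ObjectType(classOf[ClassData])),
//     NewInstance(classOf[ClassData],
//       Seq(..., AssertNotNull(..., "b", "Int")),
//       false, ObjectType(classOf[ClassData])))
// Arrays and maps have no such guard yet, hence the follow-up suggested above.
```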

+    assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+    val message = intercept[RuntimeException] {
+      buildDataset(Row(Row("hello", null))).collect()
+    }.getMessage
+
+    assert(message.contains(
+      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
+    ))
+  }
 }
 
 case class ClassData(a: String, b: Int)
 case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
+
+case class NestedStruct(f: ClassData)
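The error message also points at the user-facing remedy these tests exercise: declare optional fields with nullable types. A hypothetical sketch (`SafeData` is not part of this patch):

```scala
// Declaring the field nullable means no runtime assertion is generated for it.
case class SafeData(a: String, b: Option[Int])
// Decoding Row("hello", null) as SafeData yields SafeData("hello", None) instead
// of the "Null value appeared in non-nullable field" error; java.lang.Integer
// plays the same role for Java beans.
```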

 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
  * Java serialization -- so the only way it can be "serialized" is through our encoders.