@@ -18,9 +18,8 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection {
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
*/
def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None)
def constructorFor[T : TypeTag]: Expression = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
constructorFor(tpe, None, walkedTypePath)
}
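
For orientation, a minimal sketch (class names are hypothetical and abbreviated) of the path strings this seeds and how the recursive cases below extend them:

// Hypothetical classes, for illustration only.
case class Inner(x: Int)
case class Outer(inner: Inner)

// constructorFor[Outer] seeds walkedTypePath with the root entry; each recursive
// step (field, option value, array element) prepends one entry, so by the time
// we reach the leaf Outer.inner.x the accumulated path reads innermost-first,
// roughly:
//   - field (class: "scala.Int", name: "x")
//   - field (class: "Inner", name: "inner")
//   - root class: "Outer"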

private def constructorFor(
tpe: `Type`,
path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
path: Option[Expression],
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {

/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}

/** Returns the current path with a field at ordinal extracted. */
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
.map(p => GetStructField(p, ordinal))
.getOrElse(BoundReference(ordinal, dataType, false))
def addToPathOrdinal(
ordinal: Int,
dataType: DataType,
walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => GetStructField(p, ordinal))
.getOrElse(BoundReference(ordinal, dataType, false))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}

/** Returns the current path or `BoundReference`. */
def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
def getPath: Expression = {
val dataType = schemaFor(tpe).dataType
if (path.isDefined) {
path.get
} else {
upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
}
}

/**
* When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
* and lose the required data type, which may lead to a runtime error if the real type doesn't
* match the encoder's schema.
* For example, if we build an encoder for `case class Data(a: Int, b: String)` and the real type
* is [a: int, b: long], we will hit a runtime error saying that we can't construct class
* `Data` from an int and a long, because we lost the information that `b` should be a string.
*
* This method helps us "remember" the required data type by adding an `UpCast`. Note that we
* don't need to cast struct types because there must be an `UnresolvedExtractValue` or
* `GetStructField` wrapping them, thus we only need to handle leaf types.
*/
def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
case _: StructType => expr
case _ => UpCast(expr, expected, walkedTypePath)
}
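
As a rough sketch of the effect on a leaf field (expression trees abbreviated, field names hypothetical):

// Previously, addToPath("a") produced just the extraction:
//   UnresolvedExtractValue(path, Literal("a"))
// With this change the extraction is wrapped so the expected type survives
// analysis even though the attribute itself is still unresolved:
//   UpCast(UnresolvedExtractValue(path, Literal("a")), IntegerType, walkedTypePath)
// Struct-typed fields are left unwrapped: the UnresolvedExtractValue /
// GetStructField nodes around their children already pin down the shape.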

tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath

case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
WrapOption(constructorFor(optType, path))
val className = getClassNameFromType(optType)
val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
WrapOption(constructorFor(optType, path, newTypePath))

case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
@@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection {
primitiveMethod.map { method =>
Invoke(getPath, method, arrayClassFor(elementType))
}.getOrElse {
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
p => constructorFor(elementType, Some(p)),
p => constructorFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -230,10 +275,12 @@

case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val arrayData =
Invoke(
MapObjects(
p => constructorFor(elementType, Some(p)),
p => constructorFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -246,12 +293,13 @@
arrayData :: Nil)

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t

val keyData =
Invoke(
MapObjects(
p => constructorFor(keyType, Some(p)),
p => constructorFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
schemaFor(keyType).dataType),
"array",
@@ -260,7 +308,7 @@
val valueData =
Invoke(
MapObjects(
p => constructorFor(valueType, Some(p)),
p => constructorFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType),
"array",
@@ -297,12 +345,19 @@
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val dataType = schemaFor(fieldType).dataType

val clsName = getClassNameFromType(fieldType)
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
// For tuples, we grab the inner fields by ordinal instead of by name.
if (cls.getName startsWith "scala.Tuple") {
constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
constructorFor(
fieldType,
Some(addToPathOrdinal(i, dataType, newTypePath)),
newTypePath)
} else {
constructorFor(fieldType, Some(addToPath(fieldName)))
constructorFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
newTypePath)
}
}

@@ -72,6 +72,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveUpCast ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -1169,3 +1170,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}

/**
* Replace `UpCast` expressions with `Cast`, and throw exceptions if the cast may truncate.
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
[Inline review comment, Contributor (author): Here I just use prettyString to show the field path like `a`, `a.b`. Should we also show the type path like we did before?]
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
"type of the field in the target object")
}

private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
toPrecedence > 0 && fromPrecedence > toPrecedence
}
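
Concretely, assuming the usual Byte < Short < Int < Long < Float < Double ordering of `numericPrecedence` (the tail of that list is elided in the HiveTypeCoercion hunk further down):

// LongType -> IntegerType: indexOf(LongType) = 3 > indexOf(IntegerType) = 2,
// so the up-cast is rejected as potentially truncating.
// IntegerType -> LongType: 2 < 3, a pure widening, so it is allowed.
// Types outside the list get indexOf = -1 and fall through to the other checks.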

def apply(plan: LogicalPlan): LogicalPlan = {
plan transformAllExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u

case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
fail(child, to, walkedTypePath)
case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
fail(child, to, walkedTypePath)
case (from, to) if illegalNumericPrecedence(from, to) =>
fail(child, to, walkedTypePath)
case (TimestampType, DateType) =>
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
case _ => Cast(child, dataType)
}
}
}
}
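
To make the user-visible effect concrete, a hypothetical repro sketch (the REPL-style `sqlContext`, class and column names, and exact message wording are illustrative, not part of this patch):

import org.apache.spark.sql.AnalysisException
import sqlContext.implicits._   // assumes a REPL-style sqlContext is in scope

case class Data(a: Int, b: Int)

// Column `b` is a bigint, so mapping it onto the Int field `b` could truncate.
val df = Seq((1, 10L), (2, 20L)).toDF("a", "b")

try {
  df.as[Data]   // encoder resolution triggers ResolveUpCast during analysis
} catch {
  case e: AnalysisException =>
    // Expected message, roughly:
    //   Cannot up cast `b` from bigint to int as it may truncate
    //   The type path of the target object is:
    //   - field (class: "scala.Int", name: "b")
    //   - root class: "Data"
    //   You can either add an explicit cast to the input data or choose a
    //   higher precision type of the field in the target object
    println(e.getMessage)
}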
@@ -53,7 +53,7 @@ object HiveTypeCoercion {

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
private val numericPrecedence =
private[sql] val numericPrecedence =
IndexedSeq(
ByteType,
ShortType,
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T](

val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
val optimizedPlan = SimplifyCasts(analyzedPlan)
[Inline review thread]
Contributor: I wonder if we should just use the full optimizer here. I guess for now it won't do anything, but since it should never change the answer and we might improve it later, that might make more sense.
Contributor (author): I'm also wondering whether we should introduce a sqlContext here and use its analyzer and optimizer. For now our encoder resolution is case sensitive regardless of the CASE_SENSITIVE config.
Contributor: Probably not a SQLContext, but a CatalystConfig would be reasonable. I wonder if it should be a different setting than SQL case-sensitivity resolution? On one hand, Scala/Java are always case sensitive, so it seems reasonable to preserve that. On the other hand, if you're loading from something like Hive it would be annoying to have to fix all the columns by hand. @rxin, thoughts?
Contributor: Maybe encoders should be case sensitive all the time to begin with? It is a programming language after all, which is case sensitive. If users complain, we can consider adding that in the future.
Contributor: SGTM

// In order to construct instances of inner classes (for example those declared in a REPL cell),
// we need an instance of the outer scope. This rule substitutes those outer objects into
// expressions that are missing them by looking up the name in the SQLContext's `outerScopes`
// registry.
copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
if (outer == null) {
@@ -104,8 +104,7 @@ object Cast {
}

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType)
extends UnaryExpression with CodegenFallback {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {

override def toString: String = s"cast($child as ${dataType.simpleString})"

@@ -915,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType)
"""
}
}

/**
* Cast the child expression to the target data type, but throw an error if the cast might
* truncate, e.g. long -> int, timestamp -> date.
*/
case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
extends UnaryExpression with Unevaluable {
override lazy val resolved = false
}
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

/**
* Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
* Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
* StructType.
*/
def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
@@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
case _ => false
}

/**
* Returns whether this DecimalType is tighter than `other`. If yes, it means `this`
* can be cast to `other` safely without losing any precision or range.
*/
private[sql] def isTighterThan(other: DataType): Boolean = other match {
case dt: DecimalType =>
(precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
case dt: IntegralType =>
isTighterThan(DecimalType.forType(dt))
case _ => false
}
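
A few illustrative checks (assuming `DecimalType.forType` maps IntegerType to decimal(10, 0) and LongType to decimal(20, 0); since the method is private[sql], these only compile inside that package):

assert(DecimalType(7, 2).isTighterThan(DecimalType(9, 3)))   // 5 integral digits <= 6 and scale 2 <= 3
assert(!DecimalType(9, 3).isTighterThan(DecimalType(7, 2)))  // 6 integral digits > 5: could lose range
assert(DecimalType(5, 0).isTighterThan(LongType))            // compared against decimal(20, 0): fits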

/**
* The default size of a value of the DecimalType is 4096 bytes.
*/