Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
Expand Down Expand Up @@ -77,8 +77,9 @@ object ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,39 @@ trait HiveTypeCoercion {
case a: BinaryArithmetic if a.right.dataType == StringType =>
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))

// we should cast all timestamp/date/string compare into string compare
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about turning Date/Timestamp comparison to Long comparison? String and long representations of Timestamp are both accurate to seconds.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems cast('1970-01-01' as date) < cast('1970-01-01 00:00:00' as timestamp)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK... verified this behavior with Hive, I've no idea about this :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So Michael agreed to leave the whole ordering and comparing stuff in a separated PR :)


case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a,b.map(Cast(_,TimestampType))))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
Expand Down Expand Up @@ -283,6 +302,8 @@ trait HiveTypeCoercion {
// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
// DateType should be null if be cast to boolean.
case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import scala.language.implicitConversions

Expand Down Expand Up @@ -119,6 +119,7 @@ package object dsl {
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def dateToLiteral(d: Date) = Literal(d)
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
Expand Down Expand Up @@ -174,6 +175,9 @@ package object dsl {
/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type date */
def date = AttributeReference(s, DateType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = true)()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp
import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
override def foldable = child.foldable

override def nullable = (child.dataType, dataType) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (StringType, DateType) => true
case _ => child.nullable
}

Expand All @@ -42,6 +45,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// UDFToString
private[this] def castToString: Any => Any = child.dataType match {
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
case DateType => buildCast[Date](_, dateToString)
case TimestampType => buildCast[Timestamp](_, timestampToString)
case _ => buildCast[Any](_, _.toString)
}
Expand All @@ -56,7 +60,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType =>
buildCast[String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
// Hive would return null when cast from date to boolean
buildCast[Date](_, d => null)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving a comment here would be good. It's really unintuitive here to see a timestamp can be casted to a boolean while a date has to be null.

case LongType =>
buildCast[Long](_, _ != 0)
case IntegerType =>
Expand Down Expand Up @@ -95,6 +102,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
buildCast[Short](_, s => new Timestamp(s))
case ByteType =>
buildCast[Byte](_, b => new Timestamp(b))
case DateType =>
buildCast[Date](_, d => new Timestamp(d.getTime))
// TimestampWritable.decimalToTimestamp
case DecimalType =>
buildCast[BigDecimal](_, d => decimalToTimestamp(d))
Expand Down Expand Up @@ -130,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// Converts Timestamp to string according to Hive TimestampWritable convention
private[this] def timestampToString(ts: Timestamp): String = {
val timestampString = ts.toString
val formatted = Cast.threadLocalDateFormat.get.format(ts)
val formatted = Cast.threadLocalTimestampFormat.get.format(ts)

if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
formatted + timestampString.substring(19)
Expand All @@ -139,13 +148,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
}

// Converts Timestamp to string according to Hive TimestampWritable convention
private[this] def timestampToDateString(ts: Timestamp): String = {
Cast.threadLocalDateFormat.get.format(ts)
}

// DateConverter
private[this] def castToDate: Any => Any = child.dataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case StringType =>
buildCast[String](_, s =>
try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
// throw valid precision more than seconds, according to Hive.
// Timestamp.nanos is in 0 to 999,999,999, no more than a second.
buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000))
// Hive throws this exception as a Semantic Exception
// It is never possible to compare result when hive return with exception, so we can return null
// NULL is more reasonable here, since the query itself obeys the grammar.
case _ => _ => null
}

// Date cannot be cast to long, according to hive
private[this] def dateToLong(d: Date) = null

// Date cannot be cast to double, according to hive
private[this] def dateToDouble(d: Date) = null

// Converts Date to string according to Hive DateWritable convention
private[this] def dateToString(d: Date): String = {
Cast.threadLocalDateFormat.get.format(d)
}

// LongConverter
private[this] def castToLong: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t))
case DecimalType =>
Expand All @@ -154,13 +198,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}

// IntConverter
private[this] def castToInt: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
case DecimalType =>
Expand All @@ -169,13 +216,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}

// ShortConverter
private[this] def castToShort: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
case DateType =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Datetype can not be casted into IntegerType / ShortType / LongType, let's remove this, raise exception in compile time probably better than in runtime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hive returns NULL when cast from DATE to INT., etc

buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
case DecimalType =>
Expand All @@ -184,13 +234,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}

// ByteConverter
private[this] def castToByte: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
case DecimalType =>
Expand All @@ -199,27 +252,33 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

// DecimalConverter
private[this] def castToDecimal: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
case DateType =>
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
// Note that we lose precision here.
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
case x: NumericType =>
b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
}

// DoubleConverter
private[this] def castToDouble: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
case DateType =>
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t))
case DecimalType =>
Expand All @@ -228,13 +287,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}

// FloatConverter
private[this] def castToFloat: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1f else 0f)
case DateType =>
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
case DecimalType =>
Expand All @@ -245,17 +307,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {

private[this] lazy val cast: Any => Any = dataType match {
case dt if dt == child.dataType => identity[Any]
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
case DateType => castToDate
case TimestampType => castToTimestamp
case BooleanType => castToBoolean
case ByteType => castToByte
case ShortType => castToShort
case IntegerType => castToInt
case FloatType => castToFloat
case LongType => castToLong
case DoubleType => castToDouble
case BooleanType => castToBoolean
case ByteType => castToByte
case ShortType => castToShort
case IntegerType => castToInt
case FloatType => castToFloat
case LongType => castToLong
case DoubleType => castToDouble
}

override def eval(input: Row): Any = {
Expand All @@ -267,6 +330,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
object Cast {
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
override def initialValue() = {
new SimpleDateFormat("yyyy-MM-dd")
}
}

// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
override def initialValue() = {
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.types._

Expand All @@ -33,6 +33,7 @@ object Literal {
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(d, DecimalType)
case t: Timestamp => Literal(t, TimestampType)
case d: Date => Literal(d, DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.types

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
import scala.reflect.ClassTag
Expand Down Expand Up @@ -250,6 +250,16 @@ case object TimestampType extends NativeType {
}
}

case object DateType extends NativeType {
private[sql] type JvmType = Date

@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }

private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Date, y: Date) = x.compareTo(y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've checked the logic of java.sql.Date.compareTo, and it is not the same as DateWritable.compareTo, which is the internal representation in Hive. The former will compare its milliseconds, but the later only compare the days since Epoch, probably we need to follow the same semantic with Hive here.
BTW: If we change the logic here, does that also mean we needn't cast the date to string for BinaryPredicate in HiveTypeCoercion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to modify compareTo from TimestampType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is only used for ordering, not for data comparison.

}
}

abstract class NumericType extends NativeType with PrimitiveType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
Expand Down
Loading