@@ -179,6 +179,9 @@ abstract class Expression extends TreeNode[Expression] {
       case i: IntegralType =>
         f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
           i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
+      case i: FractionalType =>
+        f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
+          i.asIntegral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
       case other => sys.error(s"Type $other does not support numeric operations")
     }
   }
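The new `FractionalType` case reuses the existing `Integral` code path: instead of the type's `integral` evidence it passes the `asIntegral` view added below, so an expression such as `Remainder` can call `rem` on doubles, floats, and decimals. A minimal stand-alone sketch of the idea, using the standard library's `Integral` view of `Double` (the value name is illustrative, not part of the patch):

```scala
import scala.math.Numeric.DoubleAsIfIntegral

// DoubleAsIfIntegral wraps Double in an Integral, making quot/rem available.
val asIntegral: Integral[Double] = DoubleAsIfIntegral

println(asIntegral.rem(101.1, 100.0))   // ~1.1: what 101.1 % 100.0 evaluates to
println(asIntegral.quot(101.1, 100.0))  // 1.0: truncating division, unlike /
```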
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.types

 import java.sql.Timestamp

+import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
 import scala.util.parsing.combinator.RegexParsers
@@ -240,6 +241,7 @@ object FractionalType {
 }
 abstract class FractionalType extends NumericType {
   private[sql] val fractional: Fractional[JvmType]
+  private[sql] val asIntegral: Integral[JvmType]
 }

 case object DecimalType extends FractionalType {
@@ -248,6 +250,7 @@ case object DecimalType extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[BigDecimal]]
   private[sql] val fractional = implicitly[Fractional[BigDecimal]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+  private[sql] val asIntegral = BigDecimalAsIfIntegral
   def simpleString: String = "decimal"
 }

@@ -257,6 +260,7 @@ case object DoubleType extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[Double]]
   private[sql] val fractional = implicitly[Fractional[Double]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+  private[sql] val asIntegral = DoubleAsIfIntegral
   def simpleString: String = "double"
 }

@@ -266,6 +270,7 @@ case object FloatType extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[Float]]
   private[sql] val fractional = implicitly[Fractional[Float]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+  private[sql] val asIntegral = FloatAsIfIntegral
   def simpleString: String = "float"
 }

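`FloatAsIfIntegral`, `BigDecimalAsIfIntegral`, and `DoubleAsIfIntegral` are standard-library objects in `scala.math.Numeric` that implement `Integral` over the corresponding fractional types, delegating `quot` and `rem` to `BigDecimal`'s truncating division and remainder. A small sketch of the contract the three new `asIntegral` bindings satisfy (the value names here are illustrative):

```scala
import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}

// Each AsIfIntegral object is evidence of Integral for its fractional type.
val decimalEv: Integral[BigDecimal] = BigDecimalAsIfIntegral
val doubleEv: Integral[Double] = DoubleAsIfIntegral
val floatEv: Integral[Float] = FloatAsIfIntegral

// Like % on JVM primitives, rem keeps the sign of the dividend.
println(floatEv.rem(-3.1f, 2.0f))  // ~ -1.1f
```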
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.Timestamp

 import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+import org.scalautils.TripleEqualsSupport.Spread

 import org.apache.spark.sql.catalyst.types._

@@ -129,6 +131,13 @@ class ExpressionEvaluationSuite extends FunSuite {
     }
   }

+  def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = {
+    val actual = try evaluate(expression, inputRow) catch {
+      case e: Exception => fail(s"Exception evaluating $expression", e)
+    }
+    actual.asInstanceOf[Double] shouldBe expected
+  }
+
   test("IN") {
     checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
     checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
@@ -467,6 +476,29 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(c1 % c2, 1, row)
   }

+  test("fractional arithmetic") {
+    val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null))
+    val c1 = 'a.double.at(0)
+    val c2 = 'a.double.at(1)
+    val c3 = 'a.double.at(2)
+    val c4 = 'a.double.at(3)
+
+    checkEvaluation(UnaryMinus(c1), -1.1, row)
+    checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0)
+    checkEvaluation(Add(c1, c4), null, row)
+    checkEvaluation(Add(c1, c2), 3.1, row)
+    checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row)
+    checkEvaluation(Add(Literal(null, DoubleType), c2), null, row)
+    checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row)
+
+    checkEvaluation(-c1, -1.1, row)
+    checkEvaluation(c1 + c2, 3.1, row)
+    checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row)
+    checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row)
+    checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row)
+    checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row)
+  }
+
   test("BinaryComparison") {
     val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null))
     val c1 = 'a.int.at(0)
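`checkDoubleEvaluation` asserts against a ScalaTest `Spread` (built with the `+-` operator) rather than an exact value, because binary floating point makes results such as `1.1 - 2.0` only approximately `-0.9`. A self-contained illustration, assuming the same `Matchers` import the suite uses:

```scala
import org.scalatest.Matchers._

// shouldBe with a Spread passes when the value lies within the tolerance.
(1.1 - 2.0) shouldBe (-0.9 +- 0.001)
```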
@@ -0,0 +1 @@
+1 true 0.5
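This appears to be the golden answer file for the `modulus` test added to `HiveQuerySuite` below; the values follow directly from the query: `11 % 10 = 1`; `101.1 % 100.0 ≈ 1.1`, which lies between 1.01 and 1.11, so the `IF` returns `"true"`; and since `/` produces a double, `(101 / 2) % 10 = 50.5 % 10 = 0.5`.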
@@ -138,6 +138,9 @@ class HiveQuerySuite extends HiveComparisonTest {
   createQueryTest("division",
     "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")

+  createQueryTest("modulus",
+    "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1")
+
   test("Query expressed in SQL") {
     setConf("spark.sql.dialect", "sql")
     assert(sql("SELECT 1").collect() === Array(Seq(1)))