-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-8221][SQL]Add pmod function #6783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -377,3 +377,97 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { | |
| override def symbol: String = "min" | ||
| override def prettyName: String = symbol | ||
| } | ||
|
|
||
| case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should override toString to be s"pmod($left, $right)" |
||
|
|
||
| override def toString: String = s"pmod($left, $right)" | ||
|
|
||
| override def symbol: String = "pmod" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually this |
||
|
|
||
| protected def checkTypesInternal(t: DataType) = | ||
| TypeUtils.checkForNumericExpr(t, "pmod") | ||
|
|
||
| override def inputType: AbstractDataType = NumericType | ||
|
|
||
| protected override def nullSafeEval(left: Any, right: Any) = | ||
| dataType match { | ||
| case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int]) | ||
| case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long]) | ||
| case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short]) | ||
| case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) | ||
| case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float]) | ||
| case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double]) | ||
| case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about follow
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea , but seems like the clac result would be a little bit different with the expected result. private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
case i: DecimalType => i.asIntegral.asInstanceOf[Integral[Any]]
}
protected override def evalInternal(a: Any, n: Any) =
integral.rem(integral.plus(integral.rem(a, n), n), n)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok maybe you can send another PR to fix |
||
| } | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
| dataType match { | ||
| case dt: DecimalType => | ||
| val decimalAdd = "$plus" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, just check the code, you probably need to override the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, as we talked offline, this is a exceptional case in UPDATE: let's ignore the |
||
| s""" | ||
| ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fyi there is a bug here -- if we use pmod twice, this will fail codegen because r is not unique. |
||
| if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { | ||
| ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); | ||
| } else { | ||
| ${ev.primitive} = r; | ||
| } | ||
| """ | ||
| // byte and short are casted into int when add, minus, times or divide | ||
| case ByteType | ShortType => | ||
| s""" | ||
| ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); | ||
| if (r < 0) { | ||
| ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); | ||
| } else { | ||
| ${ev.primitive} = r; | ||
| } | ||
| """ | ||
| case _ => | ||
| s""" | ||
| ${ctx.javaType(dataType)} r = $eval1 % $eval2; | ||
| if (r < 0) { | ||
| ${ev.primitive} = (r + $eval2) % $eval2; | ||
| } else { | ||
| ${ev.primitive} = r; | ||
| } | ||
| """ | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| private def pmod(a: Int, n: Int): Int = { | ||
| val r = a % n | ||
| if (r < 0) {(r + n) % n} else r | ||
| } | ||
|
|
||
| private def pmod(a: Long, n: Long): Long = { | ||
| val r = a % n | ||
| if (r < 0) {(r + n) % n} else r | ||
| } | ||
|
|
||
| private def pmod(a: Byte, n: Byte): Byte = { | ||
| val r = a % n | ||
| if (r < 0) {((r + n) % n).toByte} else r.toByte | ||
| } | ||
|
|
||
| private def pmod(a: Double, n: Double): Double = { | ||
| val r = a % n | ||
| if (r < 0) {(r + n) % n} else r | ||
| } | ||
|
|
||
| private def pmod(a: Short, n: Short): Short = { | ||
| val r = a % n | ||
| if (r < 0) {((r + n) % n).toShort} else r.toShort | ||
| } | ||
|
|
||
| private def pmod(a: Float, n: Float): Float = { | ||
| val r = a % n | ||
| if (r < 0) {(r + n) % n} else r | ||
| } | ||
|
|
||
| private def pmod(a: Decimal, n: Decimal): Decimal = { | ||
| val r = a % n | ||
| if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite | |
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.types.Decimal | ||
|
|
||
|
|
||
| class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { | ||
|
|
||
| /** | ||
|
|
@@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), | ||
| Array(1.toByte, 2.toByte)) | ||
| } | ||
|
|
||
| test("pmod") { | ||
| testNumericDataTypes { convert => | ||
| val left = Literal(convert(7)) | ||
| val right = Literal(convert(3)) | ||
| checkEvaluation(Pmod(left, right), convert(1)) | ||
| checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null) | ||
| checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) | ||
| checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 | ||
| } | ||
| checkEvaluation(Pmod(-7, 3), 2) | ||
| checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) | ||
| checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need to test all supported types.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| checkEvaluation(Pmod(2L, Long.MaxValue), 2) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now we have the trait
ExpectInputTypes, probably we don't need to add the specific rule here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just try, we need this otherwise it wouldn't pass the hive compatibilitysuite
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's because implicitCast will not cast DecimalType into DecimalType.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then let's add a general rule in
ImplicitCast, and remove this, sound reasonable?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the casting rule for
pmodis not a general one, just as like the Remaindercase Remainder.... DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))is the special casting rule for the result inHiveTypeCoercion.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, let's keep the same pattern as the other arithmetic expression for now.