Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ object FunctionRegistry {
expression[Log2]("log2"),
expression[Pow]("pow"),
expression[Pow]("power"),
expression[Pmod]("pmod"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,12 @@ object HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)

case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we have the trait ExpectInputTypes, probably we don't need to add the specific rule here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just try, we need this otherwise it wouldn't pass the hive compatibilitysuite

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's because implicitCast will not cast DecimalType into DecimalType.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then let's add a general rule in ImplicitCast, and remove this, sound reasonable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the casting rule for pmod is not a general one, just as like the Remainder case Remainder.... DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) is the special casting rule for the result in HiveTypeCoercion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, let's keep the same pattern as the other arithmetic expression for now.

Cast(
Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)

// When we compare 2 decimal types with different precisions, cast them to the smallest
// common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,97 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "min"
override def prettyName: String = symbol
}

case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should override toString to be s"pmod($left, $right)"


override def toString: String = s"pmod($left, $right)"

override def symbol: String = "pmod"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this symbol is not needed as you have override toString, however we have to implement it here as it's defined abstract in BinaryOperator. We may need to fix it in the future.


protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "pmod")

override def inputType: AbstractDataType = NumericType

protected override def nullSafeEval(left: Any, right: Any) =
dataType match {
case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about follow Remainder to use Integral.rem to do the computation? That will save a lot of code here...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea , but seems like the clac result would be a little bit different with the expected result. pmod(7.2, 4.1), actual: 3.099999999999999, expected: 3.1000000000000005 . I'm keen on keeping it this way if not obvious downside.

  private lazy val integral = dataType match {
    case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
    case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
    case i: DecimalType => i.asIntegral.asInstanceOf[Integral[Any]]
  }
  protected override def evalInternal(a: Any, n: Any) = 
    integral.rem(integral.plus(integral.rem(a, n), n), n)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok maybe you can send another PR to fix Divide in this way too, as it's result is different from hive's.

}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
dataType match {
case dt: DecimalType =>
val decimalAdd = "$plus"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, just check the code, you probably need to override the decimalMethod, which is defined in BinaryArithmetic, we'd better follow the same pattern of the other ArithmeticExpressions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, as we talked offline, this is a exceptional case in ArithmeticExpression, we will take both $plus and remainder for the method decimalMethod.

UPDATE: let's ignore the decimalMethod.

s"""
${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi there is a bug here -- if we use pmod twice, this will fail codegen because r is not unique.

if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2);
} else {
${ev.primitive} = r;
}
"""
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
if (r < 0) {
${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
} else {
${ev.primitive} = r;
}
"""
case _ =>
s"""
${ctx.javaType(dataType)} r = $eval1 % $eval2;
if (r < 0) {
${ev.primitive} = (r + $eval2) % $eval2;
} else {
${ev.primitive} = r;
}
"""
}
})
}

private def pmod(a: Int, n: Int): Int = {
val r = a % n
if (r < 0) {(r + n) % n} else r
}

private def pmod(a: Long, n: Long): Long = {
val r = a % n
if (r < 0) {(r + n) % n} else r
}

private def pmod(a: Byte, n: Byte): Byte = {
val r = a % n
if (r < 0) {((r + n) % n).toByte} else r.toByte
}

private def pmod(a: Double, n: Double): Double = {
val r = a % n
if (r < 0) {(r + n) % n} else r
}

private def pmod(a: Short, n: Short): Short = {
val r = a % n
if (r < 0) {((r + n) % n).toShort} else r.toShort
}

private def pmod(a: Float, n: Float): Float = {
val r = a % n
if (r < 0) {(r + n) % n} else r
}

private def pmod(a: Decimal, n: Decimal): Decimal = {
val r = a % n
if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.Decimal


class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

/**
Expand Down Expand Up @@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
Array(1.toByte, 2.toByte))
}

test("pmod") {
testNumericDataTypes { convert =>
val left = Literal(convert(7))
val right = Literal(convert(3))
checkEvaluation(Pmod(left, right), convert(1))
checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
}
checkEvaluation(Pmod(-7, 3), 2)
checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to test all supported types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testNumericDataTypes has covered all of the numeric types and the other cover by the remain tests.

checkEvaluation(Pmod(2L, Long.MaxValue), 2)
}
}
17 changes: 17 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,23 @@ object functions {
*/
def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))

/**
* Returns the positive value of dividend mod divisor.
*
* @group math_funcs
* @since 1.5.0
*/
def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr)

/**
* Returns the positive value of dividend mod divisor.
*
* @group math_funcs
* @since 1.5.0
*/
def pmod(dividendColName: String, divisorColName: String): Column =
pmod(Column(dividendColName), Column(divisorColName))

/**
* Returns the double value that is closest in value to the argument and
* is equal to a mathematical integer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,4 +403,41 @@ class DataFrameFunctionsSuite extends QueryTest {
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}

test("pmod") {
val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
checkAnswer(
intData.select(pmod('a, 'b)),
Seq(Row(1), Row(2))
)
checkAnswer(
intData.select(pmod('a, lit(3))),
Seq(Row(1), Row(2))
)
checkAnswer(
intData.select(pmod(lit(-7), 'b)),
Seq(Row(2), Row(2))
)
checkAnswer(
intData.selectExpr("pmod(a, b)"),
Seq(Row(1), Row(2))
)
checkAnswer(
intData.selectExpr("pmod(a, 3)"),
Seq(Row(1), Row(2))
)
checkAnswer(
intData.selectExpr("pmod(-7, b)"),
Seq(Row(2), Row(2))
)
val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
checkAnswer(
doubleData.select(pmod('a, 'b)),
Seq(Row(3.1000000000000005)) // same as hive
)
checkAnswer(
doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
Seq(Row(2))
)
}
}