diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index a331a5557b45..1d7a3c735607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.types.{DataType, DoubleType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends LeafExpression with Nondeterministic { - - protected def seed: Long - +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic { rng = new XORShiftRandom(seed + partitionIndex) } + @transient protected lazy val seed: Long = child match { + case Literal(s, IntegerType) => s.asInstanceOf[Int] + case Literal(s, LongType) => s.asInstanceOf[Long] + case _ => throw new AnalysisException( + s"Input argument to $prettyName must be an integer, long or null literal.") + } + override def nullable: Boolean = false override def dataType: DataType = DoubleType - // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. - override def sql: String = s"$prettyName($seed)" + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ @@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic { 0.9629742951434543 > SELECT _FUNC_(0); 0.8446490682263027 + > SELECT _FUNC_(null); + 0.8446490682263027 """) // scalastyle:on line.size.limit -case class Rand(seed: Long) extends RDG { - override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() +case class Rand(child: Expression) extends RDG { - def this() = this(Utils.random.nextLong()) + def this() = this(Literal(Utils.random.nextLong(), LongType)) - def this(seed: Expression) = this(seed match { - case IntegerLiteral(s) => s - case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") - }) + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") @@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG { } } +object Rand { + def apply(seed: Long): Rand = Rand(Literal(seed, LongType)) +} + /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG { -0.3254147983080288 > SELECT _FUNC_(0); 1.1164209726833079 + > SELECT _FUNC_(null); + 1.1164209726833079 """) // scalastyle:on line.size.limit -case class Randn(seed: Long) extends RDG { - override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() +case class Randn(child: Expression) extends RDG { - def this() = this(Utils.random.nextLong()) + def this() = this(Literal(Utils.random.nextLong(), LongType)) - def this(seed: Expression) = this(seed match { - case IntegerLiteral(s) => s - case _ => throw new AnalysisException("Input argument to randn must be an integer literal.") - }) + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") @@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG { final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } } + +object Randn { + def apply(seed: Long): Randn = Randn(Literal(seed, LongType)) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index b7a0d44fa7e5..752c9d5449ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{IntegerType, LongType} class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001) checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001) + + checkDoubleEvaluation( + new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001) + checkDoubleEvaluation( + new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001) } test("SPARK-9127 codegen with long seed") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql new file mode 100644 index 000000000000..a1aae7b8759d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql @@ -0,0 +1,17 @@ +-- rand with the seed 0 +SELECT rand(0); +SELECT rand(cast(3 / 7 AS int)); +SELECT rand(NULL); +SELECT rand(cast(NULL AS int)); + +-- rand unsupported data type +SELECT rand(1.0); + +-- randn with the seed 0 +SELECT randn(0L); +SELECT randn(cast(3 / 7 AS long)); +SELECT randn(NULL); +SELECT randn(cast(NULL AS long)); + +-- randn unsupported data type +SELECT rand('1') diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out new file mode 100644 index 000000000000..bca67320fe7b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -0,0 +1,84 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +SELECT rand(0) +-- !query 0 schema +struct +-- !query 0 output +0.8446490682263027 + + +-- !query 1 +SELECT rand(cast(3 / 7 AS int)) +-- !query 1 schema +struct +-- !query 1 output +0.8446490682263027 + + +-- !query 2 +SELECT rand(NULL) +-- !query 2 schema +struct +-- !query 2 output +0.8446490682263027 + + +-- !query 3 +SELECT rand(cast(NULL AS int)) +-- !query 3 schema +struct +-- !query 3 output +0.8446490682263027 + + +-- !query 4 +SELECT rand(1.0) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7 + + +-- !query 5 +SELECT randn(0L) +-- !query 5 schema +struct +-- !query 5 output +1.1164209726833079 + + +-- !query 6 +SELECT randn(cast(3 / 7 AS long)) +-- !query 6 schema +struct +-- !query 6 output +1.1164209726833079 + + +-- !query 7 +SELECT randn(NULL) +-- !query 7 schema +struct +-- !query 7 output +1.1164209726833079 + + +-- !query 8 +SELECT randn(cast(NULL AS long)) +-- !query 8 schema +struct +-- !query 8 output +1.1164209726833079 + + +-- !query 9 +SELECT rand('1') +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7