Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

Expand All @@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
abstract class RDG extends LeafExpression with Nondeterministic {

protected def seed: Long

abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
Expand All @@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
rng = new XORShiftRandom(seed + partitionIndex)
}

/**
 * Seed for the random number generator, extracted from the literal `child` expression.
 *
 * Accepted inputs are integer literals, long literals, and typed null literals of either
 * type. A null literal is mapped to seed 0, which is the value the previous
 * `asInstanceOf`-based unboxing of a null produced, so `f(null)` and `f(0)` generate the
 * same sequence.
 *
 * Throws an AnalysisException if `child` is any other expression.
 */
@transient protected lazy val seed: Long = child match {
  // Handle null explicitly rather than relying on null unboxing to 0.
  case Literal(null, IntegerType | LongType) => 0L
  case Literal(intSeed: Int, IntegerType) => intSeed.toLong
  case Literal(longSeed: Long, LongType) => longSeed
  case _ => throw new AnalysisException(
    s"Input argument to $prettyName must be an integer, long or null literal.")
}

override def nullable: Boolean = false

override def dataType: DataType = DoubleType

// NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
override def sql: String = s"$prettyName($seed)"
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
Expand All @@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic {
0.9629742951434543
> SELECT _FUNC_(0);
0.8446490682263027
> SELECT _FUNC_(null);
0.8446490682263027
""")
// scalastyle:on line.size.limit
case class Rand(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
case class Rand(child: Expression) extends RDG {

def this() = this(Utils.random.nextLong())
def this() = this(Literal(Utils.random.nextLong(), LongType))

def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
Expand All @@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG {
}
}

/** Companion providing a convenience factory for building [[Rand]] from a raw long seed. */
object Rand {
  /**
   * Creates a [[Rand]] whose child expression is `seed` wrapped in a `LongType` literal.
   *
   * @param seed seed value to wrap in a literal
   */
  def apply(seed: Long): Rand = {
    val seedLiteral = Literal(seed, LongType)
    new Rand(seedLiteral)
  }
}

/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(
Expand All @@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG {
-0.3254147983080288
> SELECT _FUNC_(0);
1.1164209726833079
> SELECT _FUNC_(null);
1.1164209726833079
""")
// scalastyle:on line.size.limit
case class Randn(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
case class Randn(child: Expression) extends RDG {

def this() = this(Utils.random.nextLong())
def this() = this(Literal(Utils.random.nextLong(), LongType))

def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
})
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
Expand All @@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG {
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
}
}

/** Companion providing a convenience factory for building [[Randn]] from a raw long seed. */
object Randn {
  /**
   * Creates a [[Randn]] whose child expression is `seed` wrapped in a `LongType` literal.
   *
   * @param seed seed value to wrap in a literal
   */
  def apply(seed: Long): Randn = {
    val seedLiteral = Literal(seed, LongType)
    new Randn(seedLiteral)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

test("random") {
  // Absolute tolerance applied to every expected value below.
  val tolerance = 0.001

  // Seeds passed as plain numbers go through the companion `apply(seed: Long)` factories.
  checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- tolerance)
  checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- tolerance)

  // Typed null literals (long and integer) are accepted as seeds as well.
  val nullLongSeed = Literal.create(null, LongType)
  val nullIntSeed = Literal.create(null, IntegerType)
  checkDoubleEvaluation(new Rand(nullLongSeed), 0.8446490682263027 +- tolerance)
  checkDoubleEvaluation(new Randn(nullIntSeed), 1.1164209726833079 +- tolerance)
}

test("SPARK-9127 codegen with long seed") {
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/random.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
-- rand with the seed 0
SELECT rand(0);
SELECT rand(cast(3 / 7 AS int));
SELECT rand(NULL);
SELECT rand(cast(NULL AS int));

-- rand unsupported data type
SELECT rand(1.0);

-- randn with the seed 0
SELECT randn(0L);
SELECT randn(cast(3 / 7 AS long));
SELECT randn(NULL);
SELECT randn(cast(NULL AS long));

-- rand unsupported data type (string)
SELECT rand('1')
84 changes: 84 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/random.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 10


-- !query 0
SELECT rand(0)
-- !query 0 schema
struct<rand(0):double>
-- !query 0 output
0.8446490682263027


-- !query 1
SELECT rand(cast(3 / 7 AS int))
-- !query 1 schema
struct<rand(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS INT)):double>
-- !query 1 output
0.8446490682263027


-- !query 2
SELECT rand(NULL)
-- !query 2 schema
struct<rand(CAST(NULL AS INT)):double>
-- !query 2 output
0.8446490682263027


-- !query 3
SELECT rand(cast(NULL AS int))
-- !query 3 schema
struct<rand(CAST(NULL AS INT)):double>
-- !query 3 output
0.8446490682263027


-- !query 4
SELECT rand(1.0)
-- !query 4 schema
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7


-- !query 5
SELECT randn(0L)
-- !query 5 schema
struct<randn(0):double>
-- !query 5 output
1.1164209726833079


-- !query 6
SELECT randn(cast(3 / 7 AS long))
-- !query 6 schema
struct<randn(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS BIGINT)):double>
-- !query 6 output
1.1164209726833079


-- !query 7
SELECT randn(NULL)
-- !query 7 schema
struct<randn(CAST(NULL AS INT)):double>
-- !query 7 output
1.1164209726833079


-- !query 8
SELECT randn(cast(NULL AS long))
-- !query 8 schema
struct<randn(CAST(NULL AS BIGINT)):double>
-- !query 8 output
1.1164209726833079


-- !query 9
SELECT rand('1')
-- !query 9 schema
struct<>
-- !query 9 output
org.apache.spark.sql.AnalysisException
cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7